mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 23:42:13 +08:00
Compare commits
102 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d30ceae8d | ||
|
|
60f6ed6bf6 | ||
|
|
4a2f7d4a99 | ||
|
|
c19a393be9 | ||
|
|
938ffb002e | ||
|
|
372a01290b | ||
|
|
8b163ca49b | ||
|
|
d23810dc53 | ||
|
|
62ed5422dd | ||
|
|
2e76302af7 | ||
|
|
6553828008 | ||
|
|
adcb7bf00e | ||
|
|
876e85e7ad | ||
|
|
2e7818d688 | ||
|
|
836c4dda2b | ||
|
|
e65e9587b4 | ||
|
|
aaadd6ed04 | ||
|
|
870b21916c | ||
|
|
fb119f9a67 | ||
|
|
ad54795a24 | ||
|
|
0abe322cca | ||
|
|
b071511676 | ||
|
|
7d9a757a26 | ||
|
|
bbf4024dc7 | ||
|
|
5831eb8a6a | ||
|
|
61838cdb3d | ||
|
|
50dba656fd | ||
|
|
0e2821456c | ||
|
|
f25ac3aff5 | ||
|
|
f6341b7f2b | ||
|
|
4e257512b9 | ||
|
|
e53b34f321 | ||
|
|
12ddae0184 | ||
|
|
7b9c3f165e | ||
|
|
0b8e84f942 | ||
|
|
d9e27df9af | ||
|
|
f0fabf89a1 | ||
|
|
5bbfbcdae9 | ||
|
|
eb55947ec4 | ||
|
|
5f7e5184eb | ||
|
|
008a111268 | ||
|
|
fda753278c | ||
|
|
6c469b42ed | ||
|
|
dacf3a2a6e | ||
|
|
e6add93ae3 | ||
|
|
b2273ec695 | ||
|
|
aa89777dda | ||
|
|
1e1f3c0c74 | ||
|
|
1fab9204eb | ||
|
|
dbd3e71637 | ||
|
|
974f67211b | ||
|
|
0338c83b90 | ||
|
|
c6b3de1199 | ||
|
|
f1325e9ae6 | ||
|
|
587012396b | ||
|
|
adebd941e1 | ||
|
|
bb500b7b2a | ||
|
|
cceada7dae | ||
|
|
5c2e7ae265 | ||
|
|
420bedd615 | ||
|
|
a79f6c5e1e | ||
|
|
0484c59ead | ||
|
|
7bbf621490 | ||
|
|
ef81aeb463 | ||
|
|
22414326cc | ||
|
|
14b155c66b | ||
|
|
e99b344b2b | ||
|
|
7fd94ab78b | ||
|
|
078529e51e | ||
|
|
23a4cf11c8 | ||
|
|
d1f0902ec0 | ||
|
|
ee86dbca9d | ||
|
|
733d4c2b85 | ||
|
|
406d3f3cab | ||
|
|
1ed93a5fd0 | ||
|
|
463ddea36f | ||
|
|
e769f67699 | ||
|
|
52d2ae9708 | ||
|
|
2e59998c51 | ||
|
|
32e58115cc | ||
|
|
ba27026399 | ||
|
|
c15b419c4c | ||
|
|
5bd27a5d17 | ||
|
|
0e7b8aab8c | ||
|
|
236908c03d | ||
|
|
67d028cf50 | ||
|
|
66ba487697 | ||
|
|
8c7875aa4d | ||
|
|
145171464f | ||
|
|
e5aa676853 | ||
|
|
9b4fc42457 | ||
|
|
caae7e4603 | ||
|
|
a26db8b3e2 | ||
|
|
8e81e395b3 | ||
|
|
f0e89992f7 | ||
|
|
4eaa0cf14a | ||
|
|
e9ec2280ec | ||
|
|
bb7bfb6980 | ||
|
|
b66f97c100 | ||
|
|
b51ad0d893 | ||
|
|
4eb22d8ee9 | ||
|
|
2392e7cf99 |
38
.github/workflows/backend-ci.yml
vendored
Normal file
38
.github/workflows/backend-ci.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: backend/go.mod
|
||||||
|
check-latest: true
|
||||||
|
cache: true
|
||||||
|
- name: Run tests
|
||||||
|
working-directory: backend
|
||||||
|
run: go test ./...
|
||||||
|
|
||||||
|
golangci-lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: backend/go.mod
|
||||||
|
check-latest: true
|
||||||
|
cache: true
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v9
|
||||||
|
with:
|
||||||
|
version: v2.7
|
||||||
|
args: --timeout=5m
|
||||||
|
working-directory: backend
|
||||||
94
.github/workflows/release.yml
vendored
94
.github/workflows/release.yml
vendored
@@ -85,6 +85,19 @@ jobs:
|
|||||||
go-version: '1.24'
|
go-version: '1.24'
|
||||||
cache-dependency-path: backend/go.sum
|
cache-dependency-path: backend/go.sum
|
||||||
|
|
||||||
|
# Docker setup for GoReleaser
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Login to DockerHub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
|
||||||
- name: Fetch tags with annotations
|
- name: Fetch tags with annotations
|
||||||
run: |
|
run: |
|
||||||
# 确保获取完整的 annotated tag 信息
|
# 确保获取完整的 annotated tag 信息
|
||||||
@@ -117,87 +130,16 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
|
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
|
||||||
|
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
|
||||||
|
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
|
||||||
|
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
|
||||||
# ===========================================================================
|
# Update DockerHub description
|
||||||
# Docker Build and Push
|
|
||||||
# ===========================================================================
|
|
||||||
docker:
|
|
||||||
needs: [update-version, build-frontend]
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Download VERSION artifact
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
name: version-file
|
|
||||||
path: backend/cmd/server/
|
|
||||||
|
|
||||||
- name: Download frontend artifact
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
name: frontend-dist
|
|
||||||
path: backend/internal/web/dist/
|
|
||||||
|
|
||||||
# Extract version from tag
|
|
||||||
- name: Extract version
|
|
||||||
id: version
|
|
||||||
run: |
|
|
||||||
VERSION=${GITHUB_REF#refs/tags/v}
|
|
||||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
|
||||||
echo "Version: $VERSION"
|
|
||||||
|
|
||||||
# Set up Docker Buildx for multi-platform builds
|
|
||||||
- name: Set up QEMU
|
|
||||||
uses: docker/setup-qemu-action@v3
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
# Login to DockerHub
|
|
||||||
- name: Login to DockerHub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
|
||||||
|
|
||||||
# Extract metadata for Docker
|
|
||||||
- name: Extract Docker metadata
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@v5
|
|
||||||
with:
|
|
||||||
images: |
|
|
||||||
weishaw/sub2api
|
|
||||||
tags: |
|
|
||||||
type=semver,pattern={{version}}
|
|
||||||
type=semver,pattern={{major}}.{{minor}}
|
|
||||||
type=semver,pattern={{major}}
|
|
||||||
type=raw,value=latest,enable={{is_default_branch}}
|
|
||||||
|
|
||||||
# Build and push Docker image
|
|
||||||
- name: Build and push Docker image
|
|
||||||
uses: docker/build-push-action@v5
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
file: ./Dockerfile
|
|
||||||
platforms: linux/amd64,linux/arm64
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
build-args: |
|
|
||||||
VERSION=${{ steps.version.outputs.version }}
|
|
||||||
COMMIT=${{ github.sha }}
|
|
||||||
DATE=${{ github.event.head_commit.timestamp }}
|
|
||||||
cache-from: type=gha
|
|
||||||
cache-to: type=gha,mode=max
|
|
||||||
|
|
||||||
# Update DockerHub description (optional)
|
|
||||||
- name: Update DockerHub description
|
- name: Update DockerHub description
|
||||||
uses: peter-evans/dockerhub-description@v4
|
uses: peter-evans/dockerhub-description@v4
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
repository: weishaw/sub2api
|
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
|
||||||
short-description: "Sub2API - AI API Gateway Platform"
|
short-description: "Sub2API - AI API Gateway Platform"
|
||||||
readme-filepath: ./deploy/DOCKER.md
|
readme-filepath: ./deploy/DOCKER.md
|
||||||
|
|||||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -28,6 +28,7 @@ node_modules/
|
|||||||
frontend/node_modules/
|
frontend/node_modules/
|
||||||
frontend/dist/
|
frontend/dist/
|
||||||
*.local
|
*.local
|
||||||
|
*.tsbuildinfo
|
||||||
|
|
||||||
# 日志
|
# 日志
|
||||||
npm-debug.log*
|
npm-debug.log*
|
||||||
@@ -81,14 +82,27 @@ build/
|
|||||||
release/
|
release/
|
||||||
|
|
||||||
# 后端嵌入的前端构建产物
|
# 后端嵌入的前端构建产物
|
||||||
|
# Keep a placeholder file so `//go:embed all:dist` always has a match in CI/lint,
|
||||||
|
# while still ignoring generated frontend build outputs.
|
||||||
backend/internal/web/dist/
|
backend/internal/web/dist/
|
||||||
|
!backend/internal/web/dist/
|
||||||
|
backend/internal/web/dist/*
|
||||||
|
!backend/internal/web/dist/.keep
|
||||||
|
|
||||||
# 后端运行时缓存数据
|
# 后端运行时缓存数据
|
||||||
backend/data/
|
backend/data/
|
||||||
|
|
||||||
|
# ===================
|
||||||
|
# 本地配置文件(包含敏感信息)
|
||||||
|
# ===================
|
||||||
|
backend/config.yaml
|
||||||
|
deploy/config.yaml
|
||||||
|
backend/.installed
|
||||||
|
|
||||||
# ===================
|
# ===================
|
||||||
# 其他
|
# 其他
|
||||||
# ===================
|
# ===================
|
||||||
tests
|
tests
|
||||||
CLAUDE.md
|
CLAUDE.md
|
||||||
.claude
|
.claude
|
||||||
|
scripts
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ builds:
|
|||||||
dir: backend
|
dir: backend
|
||||||
main: ./cmd/server
|
main: ./cmd/server
|
||||||
binary: sub2api
|
binary: sub2api
|
||||||
|
flags:
|
||||||
|
- -tags=embed
|
||||||
env:
|
env:
|
||||||
- CGO_ENABLED=0
|
- CGO_ENABLED=0
|
||||||
goos:
|
goos:
|
||||||
@@ -50,10 +52,58 @@ changelog:
|
|||||||
# 禁用自动 changelog,完全使用 tag 消息
|
# 禁用自动 changelog,完全使用 tag 消息
|
||||||
disable: true
|
disable: true
|
||||||
|
|
||||||
|
# Docker images
|
||||||
|
dockers:
|
||||||
|
- id: amd64
|
||||||
|
goos: linux
|
||||||
|
goarch: amd64
|
||||||
|
image_templates:
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
|
dockerfile: Dockerfile.goreleaser
|
||||||
|
use: buildx
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||||
|
|
||||||
|
- id: arm64
|
||||||
|
goos: linux
|
||||||
|
goarch: arm64
|
||||||
|
image_templates:
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
|
dockerfile: Dockerfile.goreleaser
|
||||||
|
use: buildx
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||||
|
|
||||||
|
# Docker manifests for multi-arch support
|
||||||
|
docker_manifests:
|
||||||
|
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
|
||||||
|
image_templates:
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
|
|
||||||
|
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
|
||||||
|
image_templates:
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
|
|
||||||
|
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
|
||||||
|
image_templates:
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
|
|
||||||
|
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
|
||||||
|
image_templates:
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||||
|
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||||
|
|
||||||
release:
|
release:
|
||||||
github:
|
github:
|
||||||
owner: Wei-Shaw
|
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
|
||||||
name: sub2api
|
name: "{{ .Env.GITHUB_REPO_NAME }}"
|
||||||
draft: false
|
draft: false
|
||||||
prerelease: auto
|
prerelease: auto
|
||||||
name_template: "Sub2API {{.Version}}"
|
name_template: "Sub2API {{.Version}}"
|
||||||
@@ -71,7 +121,7 @@ release:
|
|||||||
|
|
||||||
**One-line install (Linux):**
|
**One-line install (Linux):**
|
||||||
```bash
|
```bash
|
||||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
|
curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
|
||||||
```
|
```
|
||||||
|
|
||||||
**Manual download:**
|
**Manual download:**
|
||||||
@@ -79,5 +129,5 @@ release:
|
|||||||
|
|
||||||
## 📚 Documentation
|
## 📚 Documentation
|
||||||
|
|
||||||
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
|
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
|
||||||
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
|
- [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)
|
||||||
|
|||||||
11
Dockerfile
11
Dockerfile
@@ -40,14 +40,15 @@ WORKDIR /app/backend
|
|||||||
COPY backend/go.mod backend/go.sum ./
|
COPY backend/go.mod backend/go.sum ./
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
# Copy frontend dist from previous stage
|
# Copy backend source first
|
||||||
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
|
|
||||||
|
|
||||||
# Copy backend source
|
|
||||||
COPY backend/ ./
|
COPY backend/ ./
|
||||||
|
|
||||||
# Build the binary (BuildType=release for CI builds)
|
# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
|
||||||
|
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
|
||||||
|
|
||||||
|
# Build the binary (BuildType=release for CI builds, embed frontend)
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||||
|
-tags embed \
|
||||||
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
||||||
-o /app/sub2api \
|
-o /app/sub2api \
|
||||||
./cmd/server
|
./cmd/server
|
||||||
|
|||||||
40
Dockerfile.goreleaser
Normal file
40
Dockerfile.goreleaser
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# =============================================================================
|
||||||
|
# Sub2API Dockerfile for GoReleaser
|
||||||
|
# =============================================================================
|
||||||
|
# This Dockerfile is used by GoReleaser to build Docker images.
|
||||||
|
# It only packages the pre-built binary, no compilation needed.
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
FROM alpine:3.19
|
||||||
|
|
||||||
|
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||||
|
LABEL description="Sub2API - AI API Gateway Platform"
|
||||||
|
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||||
|
|
||||||
|
# Install runtime dependencies
|
||||||
|
RUN apk add --no-cache \
|
||||||
|
ca-certificates \
|
||||||
|
tzdata \
|
||||||
|
curl \
|
||||||
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Create non-root user
|
||||||
|
RUN addgroup -g 1000 sub2api && \
|
||||||
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy pre-built binary from GoReleaser
|
||||||
|
COPY sub2api /app/sub2api
|
||||||
|
|
||||||
|
# Create data directory
|
||||||
|
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||||
|
|
||||||
|
USER sub2api
|
||||||
|
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
|
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
|
ENTRYPOINT ["/app/sub2api"]
|
||||||
31
README.md
31
README.md
@@ -16,6 +16,14 @@ English | [中文](README_CN.md)
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Demo
|
||||||
|
|
||||||
|
Try Sub2API online: **https://v2.pincc.ai/**
|
||||||
|
|
||||||
|
| Email | Password |
|
||||||
|
|-------|----------|
|
||||||
|
| admin@sub2api.com | admin123 |
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||||
@@ -208,26 +216,25 @@ Build and run from source code for development or customization.
|
|||||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||||
cd sub2api
|
cd sub2api
|
||||||
|
|
||||||
# 2. Build backend
|
# 2. Build frontend
|
||||||
cd backend
|
cd frontend
|
||||||
go build -o sub2api ./cmd/server
|
|
||||||
|
|
||||||
# 3. Build frontend
|
|
||||||
cd ../frontend
|
|
||||||
npm install
|
npm install
|
||||||
npm run build
|
npm run build
|
||||||
|
# Output will be in ../backend/internal/web/dist/
|
||||||
|
|
||||||
# 4. Copy frontend build to backend (for embedding)
|
# 3. Build backend with embedded frontend
|
||||||
cp -r dist ../backend/internal/web/
|
|
||||||
|
|
||||||
# 5. Create configuration file
|
|
||||||
cd ../backend
|
cd ../backend
|
||||||
|
go build -tags embed -o sub2api ./cmd/server
|
||||||
|
|
||||||
|
# 4. Create configuration file
|
||||||
cp ../deploy/config.example.yaml ./config.yaml
|
cp ../deploy/config.example.yaml ./config.yaml
|
||||||
|
|
||||||
# 6. Edit configuration
|
# 5. Edit configuration
|
||||||
nano config.yaml
|
nano config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **Note:** The `-tags embed` flag embeds the frontend into the binary. Without this flag, the binary will not serve the frontend UI.
|
||||||
|
|
||||||
**Key configuration in `config.yaml`:**
|
**Key configuration in `config.yaml`:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -258,7 +265,7 @@ default:
|
|||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 7. Run the application
|
# 6. Run the application
|
||||||
./sub2api
|
./sub2api
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
31
README_CN.md
31
README_CN.md
@@ -16,6 +16,14 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## 在线体验
|
||||||
|
|
||||||
|
体验地址:**https://v2.pincc.ai/**
|
||||||
|
|
||||||
|
| 邮箱 | 密码 |
|
||||||
|
|------|------|
|
||||||
|
| admin@sub2api.com | admin123 |
|
||||||
|
|
||||||
## 项目概述
|
## 项目概述
|
||||||
|
|
||||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||||
@@ -208,26 +216,25 @@ docker-compose logs -f
|
|||||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||||
cd sub2api
|
cd sub2api
|
||||||
|
|
||||||
# 2. 编译后端
|
# 2. 编译前端
|
||||||
cd backend
|
cd frontend
|
||||||
go build -o sub2api ./cmd/server
|
|
||||||
|
|
||||||
# 3. 编译前端
|
|
||||||
cd ../frontend
|
|
||||||
npm install
|
npm install
|
||||||
npm run build
|
npm run build
|
||||||
|
# 构建产物输出到 ../backend/internal/web/dist/
|
||||||
|
|
||||||
# 4. 复制前端构建产物到后端(用于嵌入)
|
# 3. 编译后端(嵌入前端)
|
||||||
cp -r dist ../backend/internal/web/
|
|
||||||
|
|
||||||
# 5. 创建配置文件
|
|
||||||
cd ../backend
|
cd ../backend
|
||||||
|
go build -tags embed -o sub2api ./cmd/server
|
||||||
|
|
||||||
|
# 4. 创建配置文件
|
||||||
cp ../deploy/config.example.yaml ./config.yaml
|
cp ../deploy/config.example.yaml ./config.yaml
|
||||||
|
|
||||||
# 6. 编辑配置
|
# 5. 编辑配置
|
||||||
nano config.yaml
|
nano config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **注意:** `-tags embed` 参数会将前端嵌入到二进制文件中。不使用此参数编译的程序将不包含前端界面。
|
||||||
|
|
||||||
**`config.yaml` 关键配置:**
|
**`config.yaml` 关键配置:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -258,7 +265,7 @@ default:
|
|||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 7. 运行应用
|
# 6. 运行应用
|
||||||
./sub2api
|
./sub2api
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
594
backend/.golangci.yml
Normal file
594
backend/.golangci.yml
Normal file
@@ -0,0 +1,594 @@
|
|||||||
|
version: "2"
|
||||||
|
|
||||||
|
linters:
|
||||||
|
default: none
|
||||||
|
enable:
|
||||||
|
- depguard
|
||||||
|
- errcheck
|
||||||
|
- govet
|
||||||
|
- ineffassign
|
||||||
|
- staticcheck
|
||||||
|
- unused
|
||||||
|
|
||||||
|
settings:
|
||||||
|
depguard:
|
||||||
|
rules:
|
||||||
|
# Enforce: service must not depend on repository.
|
||||||
|
service-no-repository:
|
||||||
|
list-mode: original
|
||||||
|
files:
|
||||||
|
- "**/internal/service/**"
|
||||||
|
deny:
|
||||||
|
- pkg: sub2api/internal/repository
|
||||||
|
desc: "service must not import repository"
|
||||||
|
handler-no-repository:
|
||||||
|
list-mode: original
|
||||||
|
files:
|
||||||
|
- "**/internal/handler/**"
|
||||||
|
deny:
|
||||||
|
- pkg: sub2api/internal/repository
|
||||||
|
desc: "handler must not import repository"
|
||||||
|
errcheck:
|
||||||
|
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||||
|
# Such cases aren't reported by default.
|
||||||
|
# Default: false
|
||||||
|
check-type-assertions: true
|
||||||
|
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`.
|
||||||
|
# Such cases aren't reported by default.
|
||||||
|
# Default: false
|
||||||
|
check-blank: false
|
||||||
|
# To disable the errcheck built-in exclude list.
|
||||||
|
# See `-excludeonly` option in https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||||
|
# Default: false
|
||||||
|
disable-default-exclusions: true
|
||||||
|
# List of functions to exclude from checking, where each entry is a single function to exclude.
|
||||||
|
# See https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||||
|
exclude-functions:
|
||||||
|
- io/ioutil.ReadFile
|
||||||
|
- io.Copy(*bytes.Buffer)
|
||||||
|
- io.Copy(os.Stdout)
|
||||||
|
- fmt.Println
|
||||||
|
- fmt.Print
|
||||||
|
- fmt.Printf
|
||||||
|
- fmt.Fprint
|
||||||
|
- fmt.Fprintf
|
||||||
|
- fmt.Fprintln
|
||||||
|
# Display function signature instead of selector.
|
||||||
|
# Default: false
|
||||||
|
verbose: true
|
||||||
|
ineffassign:
|
||||||
|
# Check escaping variables of type error, may cause false positives.
|
||||||
|
# Default: false
|
||||||
|
check-escaping-errors: true
|
||||||
|
staticcheck:
|
||||||
|
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
|
||||||
|
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
|
||||||
|
dot-import-whitelist:
|
||||||
|
- fmt
|
||||||
|
# https://staticcheck.dev/docs/configuration/options/#initialisms
|
||||||
|
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
|
||||||
|
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
|
||||||
|
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
|
||||||
|
# Default: ["200", "400", "404", "500"]
|
||||||
|
http-status-code-whitelist: [ "200", "400", "404", "500" ]
|
||||||
|
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||||
|
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||||
|
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
|
||||||
|
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||||
|
checks:
|
||||||
|
# Invalid regular expression.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1000
|
||||||
|
- SA1000
|
||||||
|
# Invalid template.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1001
|
||||||
|
- SA1001
|
||||||
|
# Invalid format in 'time.Parse'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1002
|
||||||
|
- SA1002
|
||||||
|
# Unsupported argument to functions in 'encoding/binary'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1003
|
||||||
|
- SA1003
|
||||||
|
# Suspiciously small untyped constant in 'time.Sleep'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1004
|
||||||
|
- SA1004
|
||||||
|
# Invalid first argument to 'exec.Command'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1005
|
||||||
|
- SA1005
|
||||||
|
# 'Printf' with dynamic first argument and no further arguments.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1006
|
||||||
|
- SA1006
|
||||||
|
# Invalid URL in 'net/url.Parse'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1007
|
||||||
|
- SA1007
|
||||||
|
# Non-canonical key in 'http.Header' map.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1008
|
||||||
|
- SA1008
|
||||||
|
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1010
|
||||||
|
- SA1010
|
||||||
|
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1011
|
||||||
|
- SA1011
|
||||||
|
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1012
|
||||||
|
- SA1012
|
||||||
|
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1013
|
||||||
|
- SA1013
|
||||||
|
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1014
|
||||||
|
- SA1014
|
||||||
|
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1015
|
||||||
|
- SA1015
|
||||||
|
# Trapping a signal that cannot be trapped.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1016
|
||||||
|
- SA1016
|
||||||
|
# Channels used with 'os/signal.Notify' should be buffered.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1017
|
||||||
|
- SA1017
|
||||||
|
# 'strings.Replace' called with 'n == 0', which does nothing.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1018
|
||||||
|
- SA1018
|
||||||
|
# Using a deprecated function, variable, constant or field.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1019
|
||||||
|
- SA1019
|
||||||
|
# Using an invalid host:port pair with a 'net.Listen'-related function.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1020
|
||||||
|
- SA1020
|
||||||
|
# Using 'bytes.Equal' to compare two 'net.IP'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1021
|
||||||
|
- SA1021
|
||||||
|
# Modifying the buffer in an 'io.Writer' implementation.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1023
|
||||||
|
- SA1023
|
||||||
|
# A string cutset contains duplicate characters.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1024
|
||||||
|
- SA1024
|
||||||
|
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1025
|
||||||
|
- SA1025
|
||||||
|
# Cannot marshal channels or functions.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1026
|
||||||
|
- SA1026
|
||||||
|
# Atomic access to 64-bit variable must be 64-bit aligned.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1027
|
||||||
|
- SA1027
|
||||||
|
# 'sort.Slice' can only be used on slices.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1028
|
||||||
|
- SA1028
|
||||||
|
# Inappropriate key in call to 'context.WithValue'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1029
|
||||||
|
- SA1029
|
||||||
|
# Invalid argument in call to a 'strconv' function.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1030
|
||||||
|
- SA1030
|
||||||
|
# Overlapping byte slices passed to an encoder.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1031
|
||||||
|
- SA1031
|
||||||
|
# Wrong order of arguments to 'errors.Is'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA1032
|
||||||
|
- SA1032
|
||||||
|
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA2000
|
||||||
|
- SA2000
|
||||||
|
# Empty critical section, did you mean to defer the unlock?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA2001
|
||||||
|
- SA2001
|
||||||
|
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA2002
|
||||||
|
- SA2002
|
||||||
|
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA2003
|
||||||
|
- SA2003
|
||||||
|
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA3000
|
||||||
|
- SA3000
|
||||||
|
# Assigning to 'b.N' in benchmarks distorts the results.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA3001
|
||||||
|
- SA3001
|
||||||
|
# Binary operator has identical expressions on both sides.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4000
|
||||||
|
- SA4000
|
||||||
|
# '&*x' gets simplified to 'x', it does not copy 'x'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4001
|
||||||
|
- SA4001
|
||||||
|
# Comparing unsigned values against negative values is pointless.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4003
|
||||||
|
- SA4003
|
||||||
|
# The loop exits unconditionally after one iteration.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4004
|
||||||
|
- SA4004
|
||||||
|
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4005
|
||||||
|
- SA4005
|
||||||
|
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4006
|
||||||
|
- SA4006
|
||||||
|
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4008
|
||||||
|
- SA4008
|
||||||
|
# A function argument is overwritten before its first use.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4009
|
||||||
|
- SA4009
|
||||||
|
# The result of 'append' will never be observed anywhere.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4010
|
||||||
|
- SA4010
|
||||||
|
# Break statement with no effect. Did you mean to break out of an outer loop?.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4011
|
||||||
|
- SA4011
|
||||||
|
# Comparing a value against NaN even though no value is equal to NaN.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4012
|
||||||
|
- SA4012
|
||||||
|
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4013
|
||||||
|
- SA4013
|
||||||
|
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4014
|
||||||
|
- SA4014
|
||||||
|
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4015
|
||||||
|
- SA4015
|
||||||
|
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4016
|
||||||
|
- SA4016
|
||||||
|
# Discarding the return values of a function without side effects, making the call pointless.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4017
|
||||||
|
- SA4017
|
||||||
|
# Self-assignment of variables.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4018
|
||||||
|
- SA4018
|
||||||
|
# Multiple, identical build constraints in the same file.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4019
|
||||||
|
- SA4019
|
||||||
|
# Unreachable case clause in a type switch.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4020
|
||||||
|
- SA4020
|
||||||
|
# "x = append(y)" is equivalent to "x = y".
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4021
|
||||||
|
- SA4021
|
||||||
|
# Comparing the address of a variable against nil.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4022
|
||||||
|
- SA4022
|
||||||
|
# Impossible comparison of interface value with untyped nil.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4023
|
||||||
|
- SA4023
|
||||||
|
# Checking for impossible return value from a builtin function.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4024
|
||||||
|
- SA4024
|
||||||
|
# Integer division of literals that results in zero.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4025
|
||||||
|
- SA4025
|
||||||
|
# Go constants cannot express negative zero.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4026
|
||||||
|
- SA4026
|
||||||
|
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4027
|
||||||
|
- SA4027
|
||||||
|
# 'x % 1' is always zero.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4028
|
||||||
|
- SA4028
|
||||||
|
# Ineffective attempt at sorting slice.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4029
|
||||||
|
- SA4029
|
||||||
|
# Ineffective attempt at generating random number.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4030
|
||||||
|
- SA4030
|
||||||
|
# Checking never-nil value against nil.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4031
|
||||||
|
- SA4031
|
||||||
|
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA4032
|
||||||
|
- SA4032
|
||||||
|
# Assignment to nil map.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5000
|
||||||
|
- SA5000
|
||||||
|
# Deferring 'Close' before checking for a possible error.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5001
|
||||||
|
- SA5001
|
||||||
|
# The empty for loop ("for {}") spins and can block the scheduler.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5002
|
||||||
|
- SA5002
|
||||||
|
# Defers in infinite loops will never execute.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5003
|
||||||
|
- SA5003
|
||||||
|
# "for { select { ..." with an empty default branch spins.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5004
|
||||||
|
- SA5004
|
||||||
|
# The finalizer references the finalized object, preventing garbage collection.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5005
|
||||||
|
- SA5005
|
||||||
|
# Infinite recursive call.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5007
|
||||||
|
- SA5007
|
||||||
|
# Invalid struct tag.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5008
|
||||||
|
- SA5008
|
||||||
|
# Invalid Printf call.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5009
|
||||||
|
- SA5009
|
||||||
|
# Impossible type assertion.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5010
|
||||||
|
- SA5010
|
||||||
|
# Possible nil pointer dereference.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5011
|
||||||
|
- SA5011
|
||||||
|
# Passing odd-sized slice to function expecting even size.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA5012
|
||||||
|
- SA5012
|
||||||
|
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6000
|
||||||
|
- SA6000
|
||||||
|
# Missing an optimization opportunity when indexing maps by byte slices.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6001
|
||||||
|
- SA6001
|
||||||
|
# Storing non-pointer values in 'sync.Pool' allocates memory.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6002
|
||||||
|
- SA6002
|
||||||
|
# Converting a string to a slice of runes before ranging over it.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6003
|
||||||
|
- SA6003
|
||||||
|
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6005
|
||||||
|
- SA6005
|
||||||
|
# Using io.WriteString to write '[]byte'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA6006
|
||||||
|
- SA6006
|
||||||
|
# Defers in range loops may not run when you expect them to.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9001
|
||||||
|
- SA9001
|
||||||
|
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9002
|
||||||
|
- SA9002
|
||||||
|
# Empty body in an if or else branch.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9003
|
||||||
|
- SA9003
|
||||||
|
# Only the first constant has an explicit type.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9004
|
||||||
|
- SA9004
|
||||||
|
# Trying to marshal a struct with no public fields nor custom marshaling.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9005
|
||||||
|
- SA9005
|
||||||
|
# Dubious bit shifting of a fixed size integer value.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9006
|
||||||
|
- SA9006
|
||||||
|
# Deleting a directory that shouldn't be deleted.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9007
|
||||||
|
- SA9007
|
||||||
|
# 'else' branch of a type assertion is probably not reading the right value.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9008
|
||||||
|
- SA9008
|
||||||
|
# Ineffectual Go compiler directive.
|
||||||
|
# https://staticcheck.dev/docs/checks/#SA9009
|
||||||
|
- SA9009
|
||||||
|
# Incorrect or missing package comment.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1000
|
||||||
|
- ST1000
|
||||||
|
# Dot imports are discouraged.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1001
|
||||||
|
- ST1001
|
||||||
|
# Poorly chosen identifier.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1003
|
||||||
|
- ST1003
|
||||||
|
# Incorrectly formatted error string.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1005
|
||||||
|
- ST1005
|
||||||
|
# Poorly chosen receiver name.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1006
|
||||||
|
- ST1006
|
||||||
|
# A function's error value should be its last return value.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1008
|
||||||
|
- ST1008
|
||||||
|
# Poorly chosen name for variable of type 'time.Duration'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1011
|
||||||
|
- ST1011
|
||||||
|
# Poorly chosen name for error variable.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1012
|
||||||
|
- ST1012
|
||||||
|
# Should use constants for HTTP error codes, not magic numbers.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1013
|
||||||
|
- ST1013
|
||||||
|
# A switch's default case should be the first or last case.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1015
|
||||||
|
- ST1015
|
||||||
|
# Use consistent method receiver names.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1016
|
||||||
|
- ST1016
|
||||||
|
# Don't use Yoda conditions.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1017
|
||||||
|
- ST1017
|
||||||
|
# Avoid zero-width and control characters in string literals.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1018
|
||||||
|
- ST1018
|
||||||
|
# Importing the same package multiple times.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1019
|
||||||
|
- ST1019
|
||||||
|
# The documentation of an exported function should start with the function's name.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1020
|
||||||
|
- ST1020
|
||||||
|
# The documentation of an exported type should start with type's name.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1021
|
||||||
|
- ST1021
|
||||||
|
# The documentation of an exported variable or constant should start with variable's name.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1022
|
||||||
|
- ST1022
|
||||||
|
# Redundant type in variable declaration.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1023
|
||||||
|
- ST1023
|
||||||
|
# Use plain channel send or receive instead of single-case select.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1000
|
||||||
|
- S1000
|
||||||
|
# Replace for loop with call to copy.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1001
|
||||||
|
- S1001
|
||||||
|
# Omit comparison with boolean constant.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1002
|
||||||
|
- S1002
|
||||||
|
# Replace call to 'strings.Index' with 'strings.Contains'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1003
|
||||||
|
- S1003
|
||||||
|
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1004
|
||||||
|
- S1004
|
||||||
|
# Drop unnecessary use of the blank identifier.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1005
|
||||||
|
- S1005
|
||||||
|
# Use "for { ... }" for infinite loops.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1006
|
||||||
|
- S1006
|
||||||
|
# Simplify regular expression by using raw string literal.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1007
|
||||||
|
- S1007
|
||||||
|
# Simplify returning boolean expression.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1008
|
||||||
|
- S1008
|
||||||
|
# Omit redundant nil check on slices, maps, and channels.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1009
|
||||||
|
- S1009
|
||||||
|
# Omit default slice index.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1010
|
||||||
|
- S1010
|
||||||
|
# Use a single 'append' to concatenate two slices.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1011
|
||||||
|
- S1011
|
||||||
|
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1012
|
||||||
|
- S1012
|
||||||
|
# Use a type conversion instead of manually copying struct fields.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1016
|
||||||
|
- S1016
|
||||||
|
# Replace manual trimming with 'strings.TrimPrefix'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1017
|
||||||
|
- S1017
|
||||||
|
# Use "copy" for sliding elements.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1018
|
||||||
|
- S1018
|
||||||
|
# Simplify "make" call by omitting redundant arguments.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1019
|
||||||
|
- S1019
|
||||||
|
# Omit redundant nil check in type assertion.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1020
|
||||||
|
- S1020
|
||||||
|
# Merge variable declaration and assignment.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1021
|
||||||
|
- S1021
|
||||||
|
# Omit redundant control flow.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1023
|
||||||
|
- S1023
|
||||||
|
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1024
|
||||||
|
- S1024
|
||||||
|
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1025
|
||||||
|
- S1025
|
||||||
|
# Simplify error construction with 'fmt.Errorf'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1028
|
||||||
|
- S1028
|
||||||
|
# Range over the string directly.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1029
|
||||||
|
- S1029
|
||||||
|
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1030
|
||||||
|
- S1030
|
||||||
|
# Omit redundant nil check around loop.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1031
|
||||||
|
- S1031
|
||||||
|
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1032
|
||||||
|
- S1032
|
||||||
|
# Unnecessary guard around call to "delete".
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1033
|
||||||
|
- S1033
|
||||||
|
# Use result of type assertion to simplify cases.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1034
|
||||||
|
- S1034
|
||||||
|
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1035
|
||||||
|
- S1035
|
||||||
|
# Unnecessary guard around map access.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1036
|
||||||
|
- S1036
|
||||||
|
# Elaborate way of sleeping.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1037
|
||||||
|
- S1037
|
||||||
|
# Unnecessarily complex way of printing formatted string.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1038
|
||||||
|
- S1038
|
||||||
|
# Unnecessary use of 'fmt.Sprint'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1039
|
||||||
|
- S1039
|
||||||
|
# Type assertion to current type.
|
||||||
|
# https://staticcheck.dev/docs/checks/#S1040
|
||||||
|
- S1040
|
||||||
|
# Apply De Morgan's law.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1001
|
||||||
|
- QF1001
|
||||||
|
# Convert untagged switch to tagged switch.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1002
|
||||||
|
- QF1002
|
||||||
|
# Convert if/else-if chain to tagged switch.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1003
|
||||||
|
- QF1003
|
||||||
|
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1004
|
||||||
|
- QF1004
|
||||||
|
# Expand call to 'math.Pow'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1005
|
||||||
|
- QF1005
|
||||||
|
# Lift 'if'+'break' into loop condition.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1006
|
||||||
|
- QF1006
|
||||||
|
# Merge conditional assignment into variable declaration.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1007
|
||||||
|
- QF1007
|
||||||
|
# Omit embedded fields from selector expression.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1008
|
||||||
|
- QF1008
|
||||||
|
# Use 'time.Time.Equal' instead of '==' operator.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1009
|
||||||
|
- QF1009
|
||||||
|
# Convert slice of bytes to string when printing it.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1010
|
||||||
|
- QF1010
|
||||||
|
# Omit redundant type from variable declaration.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1011
|
||||||
|
- QF1011
|
||||||
|
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1012
|
||||||
|
- QF1012
|
||||||
|
unused:
|
||||||
|
# Mark all struct fields that have been written to as used.
|
||||||
|
# Default: true
|
||||||
|
field-writes-are-uses: false
|
||||||
|
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
|
||||||
|
# Default: false
|
||||||
|
post-statements-are-reads: true
|
||||||
|
# Mark all exported fields as used.
|
||||||
|
# default: true
|
||||||
|
exported-fields-are-used: false
|
||||||
|
# Mark all function parameters as used.
|
||||||
|
# default: true
|
||||||
|
parameters-are-used: true
|
||||||
|
# Mark all local variables as used.
|
||||||
|
# default: true
|
||||||
|
local-variables-are-used: false
|
||||||
|
# Mark all identifiers inside generated files as used.
|
||||||
|
# Default: true
|
||||||
|
generated-is-used: false
|
||||||
|
|
||||||
|
formatters:
|
||||||
|
enable:
|
||||||
|
- gofmt
|
||||||
|
settings:
|
||||||
|
gofmt:
|
||||||
|
# Simplify code: gofmt with `-s` option.
|
||||||
|
# Default: true
|
||||||
|
simplify: false
|
||||||
|
# Apply the rewrite rules to the source before reformatting.
|
||||||
|
# https://pkg.go.dev/cmd/gofmt
|
||||||
|
# Default: []
|
||||||
|
rewrite-rules:
|
||||||
|
- pattern: 'interface{}'
|
||||||
|
replacement: 'any'
|
||||||
|
- pattern: 'a[b:len(a)]'
|
||||||
|
replacement: 'a[b:]'
|
||||||
16
backend/Makefile
Normal file
16
backend/Makefile
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
.PHONY: wire build build-embed
|
||||||
|
|
||||||
|
wire:
|
||||||
|
@echo "生成 Wire 代码..."
|
||||||
|
@cd cmd/server && go generate
|
||||||
|
@echo "Wire 代码生成完成"
|
||||||
|
|
||||||
|
build:
|
||||||
|
@echo "构建后端(不嵌入前端)..."
|
||||||
|
@go build -o bin/server ./cmd/server
|
||||||
|
@echo "构建完成: bin/server"
|
||||||
|
|
||||||
|
build-embed:
|
||||||
|
@echo "构建后端(嵌入前端)..."
|
||||||
|
@go build -tags embed -o bin/server ./cmd/server
|
||||||
|
@echo "构建完成: bin/server (with embedded frontend)"
|
||||||
@@ -1,8 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
|
//go:generate go run github.com/google/wire/cmd/wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -12,21 +15,13 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
"sub2api/internal/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||||
"sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||||
"sub2api/internal/repository"
|
|
||||||
"sub2api/internal/service"
|
|
||||||
"sub2api/internal/setup"
|
|
||||||
"sub2api/internal/web"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
"gorm.io/driver/postgres"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"gorm.io/gorm/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed VERSION
|
//go:embed VERSION
|
||||||
@@ -100,8 +95,10 @@ func runSetupServer() {
|
|||||||
r.Use(web.ServeEmbeddedFrontend())
|
r.Use(web.ServeEmbeddedFrontend())
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := ":8080"
|
// Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
|
||||||
log.Printf("Setup wizard available at http://localhost%s", addr)
|
// This allows users to run setup on a different address if needed
|
||||||
|
addr := config.GetServerAddress()
|
||||||
|
log.Printf("Setup wizard available at http://%s", addr)
|
||||||
log.Println("Complete the setup wizard to configure Sub2API")
|
log.Println("Complete the setup wizard to configure Sub2API")
|
||||||
|
|
||||||
if err := r.Run(addr); err != nil {
|
if err := r.Run(addr); err != nil {
|
||||||
@@ -110,78 +107,25 @@ func runSetupServer() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runMainServer() {
|
func runMainServer() {
|
||||||
// 加载配置
|
|
||||||
cfg, err := config.Load()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化时区(类似 PHP 的 date_default_timezone_set)
|
|
||||||
if err := timezone.Init(cfg.Timezone); err != nil {
|
|
||||||
log.Fatalf("Failed to initialize timezone: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化数据库
|
|
||||||
db, err := initDB(cfg)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to connect to database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化Redis
|
|
||||||
rdb := initRedis(cfg)
|
|
||||||
|
|
||||||
// 初始化Repository
|
|
||||||
repos := repository.NewRepositories(db)
|
|
||||||
|
|
||||||
// 初始化Service
|
|
||||||
services := service.NewServices(repos, rdb, cfg)
|
|
||||||
|
|
||||||
// 初始化Handler
|
|
||||||
buildInfo := handler.BuildInfo{
|
buildInfo := handler.BuildInfo{
|
||||||
Version: Version,
|
Version: Version,
|
||||||
BuildType: BuildType,
|
BuildType: BuildType,
|
||||||
}
|
}
|
||||||
handlers := handler.NewHandlers(services, repos, rdb, buildInfo)
|
|
||||||
|
|
||||||
// 设置Gin模式
|
app, err := initializeApplication(buildInfo)
|
||||||
if cfg.Server.Mode == "release" {
|
if err != nil {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
log.Fatalf("Failed to initialize application: %v", err)
|
||||||
}
|
|
||||||
|
|
||||||
// 创建路由
|
|
||||||
r := gin.New()
|
|
||||||
r.Use(gin.Recovery())
|
|
||||||
r.Use(middleware.Logger())
|
|
||||||
r.Use(middleware.CORS())
|
|
||||||
|
|
||||||
// 注册路由
|
|
||||||
registerRoutes(r, handlers, services, repos)
|
|
||||||
|
|
||||||
// Serve embedded frontend if available
|
|
||||||
if web.HasEmbeddedFrontend() {
|
|
||||||
r.Use(web.ServeEmbeddedFrontend())
|
|
||||||
}
|
}
|
||||||
|
defer app.Cleanup()
|
||||||
|
|
||||||
// 启动服务器
|
// 启动服务器
|
||||||
srv := &http.Server{
|
|
||||||
Addr: cfg.Server.Address(),
|
|
||||||
Handler: r,
|
|
||||||
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
|
|
||||||
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
|
|
||||||
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
|
|
||||||
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
|
|
||||||
// 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟
|
|
||||||
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
|
||||||
}
|
|
||||||
|
|
||||||
// 优雅关闭
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := app.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatalf("Failed to start server: %v", err)
|
log.Fatalf("Failed to start server: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Printf("Server started on %s", cfg.Server.Address())
|
log.Printf("Server started on %s", app.Server.Addr)
|
||||||
|
|
||||||
// 等待中断信号
|
// 等待中断信号
|
||||||
quit := make(chan os.Signal, 1)
|
quit := make(chan os.Signal, 1)
|
||||||
@@ -193,289 +137,9 @@ func runMainServer() {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
if err := app.Server.Shutdown(ctx); err != nil {
|
||||||
log.Fatalf("Server forced to shutdown: %v", err)
|
log.Fatalf("Server forced to shutdown: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Server exited")
|
log.Println("Server exited")
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDB(cfg *config.Config) (*gorm.DB, error) {
|
|
||||||
gormConfig := &gorm.Config{}
|
|
||||||
if cfg.Server.Mode == "debug" {
|
|
||||||
gormConfig.Logger = logger.Default.LogMode(logger.Info)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用带时区的 DSN 连接数据库
|
|
||||||
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 自动迁移(始终执行,确保数据库结构与代码同步)
|
|
||||||
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
|
|
||||||
if err := model.AutoMigrate(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func initRedis(cfg *config.Config) *redis.Client {
|
|
||||||
return redis.NewClient(&redis.Options{
|
|
||||||
Addr: cfg.Redis.Address(),
|
|
||||||
Password: cfg.Redis.Password,
|
|
||||||
DB: cfg.Redis.DB,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
|
|
||||||
// 健康检查
|
|
||||||
r.GET("/health", func(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup status endpoint (always returns needs_setup: false in normal mode)
|
|
||||||
// This is used by the frontend to detect when the service has restarted after setup
|
|
||||||
r.GET("/setup/status", func(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"code": 0,
|
|
||||||
"data": gin.H{
|
|
||||||
"needs_setup": false,
|
|
||||||
"step": "completed",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
// API v1
|
|
||||||
v1 := r.Group("/api/v1")
|
|
||||||
{
|
|
||||||
// 公开接口
|
|
||||||
auth := v1.Group("/auth")
|
|
||||||
{
|
|
||||||
auth.POST("/register", h.Auth.Register)
|
|
||||||
auth.POST("/login", h.Auth.Login)
|
|
||||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 公开设置(无需认证)
|
|
||||||
settings := v1.Group("/settings")
|
|
||||||
{
|
|
||||||
settings.GET("/public", h.Setting.GetPublicSettings)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 需要认证的接口
|
|
||||||
authenticated := v1.Group("")
|
|
||||||
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
|
|
||||||
{
|
|
||||||
// 当前用户信息
|
|
||||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
|
||||||
|
|
||||||
// 用户接口
|
|
||||||
user := authenticated.Group("/user")
|
|
||||||
{
|
|
||||||
user.GET("/profile", h.User.GetProfile)
|
|
||||||
user.PUT("/password", h.User.ChangePassword)
|
|
||||||
}
|
|
||||||
|
|
||||||
// API Key管理
|
|
||||||
keys := authenticated.Group("/keys")
|
|
||||||
{
|
|
||||||
keys.GET("", h.APIKey.List)
|
|
||||||
keys.GET("/:id", h.APIKey.GetByID)
|
|
||||||
keys.POST("", h.APIKey.Create)
|
|
||||||
keys.PUT("/:id", h.APIKey.Update)
|
|
||||||
keys.DELETE("/:id", h.APIKey.Delete)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用户可用分组(非管理员接口)
|
|
||||||
groups := authenticated.Group("/groups")
|
|
||||||
{
|
|
||||||
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用记录
|
|
||||||
usage := authenticated.Group("/usage")
|
|
||||||
{
|
|
||||||
usage.GET("", h.Usage.List)
|
|
||||||
usage.GET("/:id", h.Usage.GetByID)
|
|
||||||
usage.GET("/stats", h.Usage.Stats)
|
|
||||||
// User dashboard endpoints
|
|
||||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
|
||||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
|
||||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
|
||||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 卡密兑换
|
|
||||||
redeem := authenticated.Group("/redeem")
|
|
||||||
{
|
|
||||||
redeem.POST("", h.Redeem.Redeem)
|
|
||||||
redeem.GET("/history", h.Redeem.GetHistory)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用户订阅
|
|
||||||
subscriptions := authenticated.Group("/subscriptions")
|
|
||||||
{
|
|
||||||
subscriptions.GET("", h.Subscription.List)
|
|
||||||
subscriptions.GET("/active", h.Subscription.GetActive)
|
|
||||||
subscriptions.GET("/progress", h.Subscription.GetProgress)
|
|
||||||
subscriptions.GET("/summary", h.Subscription.GetSummary)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 管理员接口
|
|
||||||
admin := v1.Group("/admin")
|
|
||||||
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
|
|
||||||
{
|
|
||||||
// 仪表盘
|
|
||||||
dashboard := admin.Group("/dashboard")
|
|
||||||
{
|
|
||||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
|
||||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
|
||||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
|
||||||
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
|
||||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
|
||||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
|
||||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
|
||||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用户管理
|
|
||||||
users := admin.Group("/users")
|
|
||||||
{
|
|
||||||
users.GET("", h.Admin.User.List)
|
|
||||||
users.GET("/:id", h.Admin.User.GetByID)
|
|
||||||
users.POST("", h.Admin.User.Create)
|
|
||||||
users.PUT("/:id", h.Admin.User.Update)
|
|
||||||
users.DELETE("/:id", h.Admin.User.Delete)
|
|
||||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
|
||||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
|
||||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 分组管理
|
|
||||||
groups := admin.Group("/groups")
|
|
||||||
{
|
|
||||||
groups.GET("", h.Admin.Group.List)
|
|
||||||
groups.GET("/all", h.Admin.Group.GetAll)
|
|
||||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
|
||||||
groups.POST("", h.Admin.Group.Create)
|
|
||||||
groups.PUT("/:id", h.Admin.Group.Update)
|
|
||||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
|
||||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
|
||||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 账号管理
|
|
||||||
accounts := admin.Group("/accounts")
|
|
||||||
{
|
|
||||||
accounts.GET("", h.Admin.Account.List)
|
|
||||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
|
||||||
accounts.POST("", h.Admin.Account.Create)
|
|
||||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
|
||||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
|
||||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
|
||||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
|
||||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
|
||||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
|
||||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
|
||||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
|
||||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
|
||||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
|
||||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
|
||||||
|
|
||||||
// OAuth routes
|
|
||||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
|
||||||
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
|
||||||
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
|
||||||
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
|
|
||||||
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
|
|
||||||
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 代理管理
|
|
||||||
proxies := admin.Group("/proxies")
|
|
||||||
{
|
|
||||||
proxies.GET("", h.Admin.Proxy.List)
|
|
||||||
proxies.GET("/all", h.Admin.Proxy.GetAll)
|
|
||||||
proxies.GET("/:id", h.Admin.Proxy.GetByID)
|
|
||||||
proxies.POST("", h.Admin.Proxy.Create)
|
|
||||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
|
||||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
|
||||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
|
||||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
|
||||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
|
||||||
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 卡密管理
|
|
||||||
codes := admin.Group("/redeem-codes")
|
|
||||||
{
|
|
||||||
codes.GET("", h.Admin.Redeem.List)
|
|
||||||
codes.GET("/stats", h.Admin.Redeem.GetStats)
|
|
||||||
codes.GET("/export", h.Admin.Redeem.Export)
|
|
||||||
codes.GET("/:id", h.Admin.Redeem.GetByID)
|
|
||||||
codes.POST("/generate", h.Admin.Redeem.Generate)
|
|
||||||
codes.DELETE("/:id", h.Admin.Redeem.Delete)
|
|
||||||
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
|
|
||||||
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 系统设置
|
|
||||||
adminSettings := admin.Group("/settings")
|
|
||||||
{
|
|
||||||
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
|
||||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
|
||||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
|
||||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 系统管理
|
|
||||||
system := admin.Group("/system")
|
|
||||||
{
|
|
||||||
system.GET("/version", h.Admin.System.GetVersion)
|
|
||||||
system.GET("/check-updates", h.Admin.System.CheckUpdates)
|
|
||||||
system.POST("/update", h.Admin.System.PerformUpdate)
|
|
||||||
system.POST("/rollback", h.Admin.System.Rollback)
|
|
||||||
system.POST("/restart", h.Admin.System.RestartService)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 订阅管理
|
|
||||||
subscriptions := admin.Group("/subscriptions")
|
|
||||||
{
|
|
||||||
subscriptions.GET("", h.Admin.Subscription.List)
|
|
||||||
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
|
|
||||||
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
|
|
||||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
|
||||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
|
||||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
|
||||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 分组下的订阅列表
|
|
||||||
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
|
|
||||||
|
|
||||||
// 用户下的订阅列表
|
|
||||||
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
|
|
||||||
|
|
||||||
// 使用记录管理
|
|
||||||
usage := admin.Group("/usage")
|
|
||||||
{
|
|
||||||
usage.GET("", h.Admin.Usage.List)
|
|
||||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
|
||||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
|
||||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// API网关(Claude API兼容)
|
|
||||||
gateway := r.Group("/v1")
|
|
||||||
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
|
|
||||||
{
|
|
||||||
gateway.POST("/messages", h.Gateway.Messages)
|
|
||||||
gateway.GET("/models", h.Gateway.Models)
|
|
||||||
gateway.GET("/usage", h.Gateway.Usage)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
125
backend/cmd/server/wire.go
Normal file
125
backend/cmd/server/wire.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
//go:build wireinject
|
||||||
|
// +build wireinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/wire"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Application struct {
|
||||||
|
Server *http.Server
|
||||||
|
Cleanup func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||||
|
wire.Build(
|
||||||
|
// 基础设施层 ProviderSets
|
||||||
|
config.ProviderSet,
|
||||||
|
infrastructure.ProviderSet,
|
||||||
|
|
||||||
|
// 业务层 ProviderSets
|
||||||
|
repository.ProviderSet,
|
||||||
|
service.ProviderSet,
|
||||||
|
handler.ProviderSet,
|
||||||
|
|
||||||
|
// 服务器层 ProviderSet
|
||||||
|
server.ProviderSet,
|
||||||
|
|
||||||
|
// BuildInfo provider
|
||||||
|
provideServiceBuildInfo,
|
||||||
|
|
||||||
|
// 清理函数提供者
|
||||||
|
provideCleanup,
|
||||||
|
|
||||||
|
// 应用程序结构体
|
||||||
|
wire.Struct(new(Application), "Server", "Cleanup"),
|
||||||
|
)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
|
return service.BuildInfo{
|
||||||
|
Version: buildInfo.Version,
|
||||||
|
BuildType: buildInfo.BuildType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func provideCleanup(
|
||||||
|
db *gorm.DB,
|
||||||
|
rdb *redis.Client,
|
||||||
|
services *service.Services,
|
||||||
|
) func() {
|
||||||
|
return func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Cleanup steps in reverse dependency order
|
||||||
|
cleanupSteps := []struct {
|
||||||
|
name string
|
||||||
|
fn func() error
|
||||||
|
}{
|
||||||
|
{"TokenRefreshService", func() error {
|
||||||
|
services.TokenRefresh.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"PricingService", func() error {
|
||||||
|
services.Pricing.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"EmailQueueService", func() error {
|
||||||
|
services.EmailQueue.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OAuthService", func() error {
|
||||||
|
services.OAuth.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpenAIOAuthService", func() error {
|
||||||
|
services.OpenAIOAuth.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"Redis", func() error {
|
||||||
|
return rdb.Close()
|
||||||
|
}},
|
||||||
|
{"Database", func() error {
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sqlDB.Close()
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, step := range cleanupSteps {
|
||||||
|
if err := step.fn(); err != nil {
|
||||||
|
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||||
|
// Continue with remaining cleanup steps even if one fails
|
||||||
|
} else {
|
||||||
|
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if context timed out
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
|
||||||
|
default:
|
||||||
|
log.Printf("[Cleanup] All cleanup steps completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
249
backend/cmd/server/wire_gen.go
Normal file
249
backend/cmd/server/wire_gen.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
// Code generated by Wire. DO NOT EDIT.
|
||||||
|
|
||||||
|
//go:generate go run -mod=mod github.com/google/wire/cmd/wire
|
||||||
|
//go:build !wireinject
|
||||||
|
// +build !wireinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Injectors from wire.go:
|
||||||
|
|
||||||
|
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||||
|
configConfig, err := config.ProvideConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
db, err := infrastructure.ProvideDB(configConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userRepository := repository.NewUserRepository(db)
|
||||||
|
settingRepository := repository.NewSettingRepository(db)
|
||||||
|
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||||
|
client := infrastructure.ProvideRedis(configConfig)
|
||||||
|
emailCache := repository.NewEmailCache(client)
|
||||||
|
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||||
|
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||||
|
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
|
||||||
|
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||||
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||||
|
authHandler := handler.NewAuthHandler(authService)
|
||||||
|
userService := service.NewUserService(userRepository)
|
||||||
|
userHandler := handler.NewUserHandler(userService)
|
||||||
|
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||||
|
groupRepository := repository.NewGroupRepository(db)
|
||||||
|
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
|
||||||
|
apiKeyCache := repository.NewApiKeyCache(client)
|
||||||
|
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||||
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
|
usageLogRepository := repository.NewUsageLogRepository(db)
|
||||||
|
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||||
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
|
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
||||||
|
billingCache := repository.NewBillingCache(client)
|
||||||
|
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
|
||||||
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||||
|
redeemCache := repository.NewRedeemCache(client)
|
||||||
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||||
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
|
dashboardService := service.NewDashboardService(usageLogRepository)
|
||||||
|
dashboardHandler := admin.NewDashboardHandler(dashboardService)
|
||||||
|
accountRepository := repository.NewAccountRepository(db)
|
||||||
|
proxyRepository := repository.NewProxyRepository(db)
|
||||||
|
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||||
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
||||||
|
adminUserHandler := admin.NewUserHandler(adminService)
|
||||||
|
groupHandler := admin.NewGroupHandler(adminService)
|
||||||
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
|
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||||
|
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||||
|
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||||
|
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||||
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||||
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
|
||||||
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
|
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
|
||||||
|
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||||
|
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||||
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
|
||||||
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||||
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
|
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||||
|
proxyHandler := admin.NewProxyHandler(adminService)
|
||||||
|
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||||
|
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||||
|
updateCache := repository.NewUpdateCache(client)
|
||||||
|
gitHubReleaseClient := repository.NewGitHubReleaseClient()
|
||||||
|
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||||
|
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
|
||||||
|
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||||
|
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||||
|
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
||||||
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||||
|
gatewayCache := repository.NewGatewayCache(client)
|
||||||
|
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||||
|
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
|
identityCache := repository.NewIdentityCache(client)
|
||||||
|
identityService := service.NewIdentityService(identityCache)
|
||||||
|
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
|
||||||
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||||
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
|
||||||
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||||
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||||
|
groupService := service.NewGroupService(groupRepository)
|
||||||
|
accountService := service.NewAccountService(accountRepository, groupRepository)
|
||||||
|
proxyService := service.NewProxyService(proxyRepository)
|
||||||
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
|
||||||
|
services := &service.Services{
|
||||||
|
Auth: authService,
|
||||||
|
User: userService,
|
||||||
|
ApiKey: apiKeyService,
|
||||||
|
Group: groupService,
|
||||||
|
Account: accountService,
|
||||||
|
Proxy: proxyService,
|
||||||
|
Redeem: redeemService,
|
||||||
|
Usage: usageService,
|
||||||
|
Pricing: pricingService,
|
||||||
|
Billing: billingService,
|
||||||
|
BillingCache: billingCacheService,
|
||||||
|
Admin: adminService,
|
||||||
|
Gateway: gatewayService,
|
||||||
|
OpenAIGateway: openAIGatewayService,
|
||||||
|
OAuth: oAuthService,
|
||||||
|
OpenAIOAuth: openAIOAuthService,
|
||||||
|
RateLimit: rateLimitService,
|
||||||
|
AccountUsage: accountUsageService,
|
||||||
|
AccountTest: accountTestService,
|
||||||
|
Setting: settingService,
|
||||||
|
Email: emailService,
|
||||||
|
EmailQueue: emailQueueService,
|
||||||
|
Turnstile: turnstileService,
|
||||||
|
Subscription: subscriptionService,
|
||||||
|
Concurrency: concurrencyService,
|
||||||
|
Identity: identityService,
|
||||||
|
Update: updateService,
|
||||||
|
TokenRefresh: tokenRefreshService,
|
||||||
|
}
|
||||||
|
repositories := &repository.Repositories{
|
||||||
|
User: userRepository,
|
||||||
|
ApiKey: apiKeyRepository,
|
||||||
|
Group: groupRepository,
|
||||||
|
Account: accountRepository,
|
||||||
|
Proxy: proxyRepository,
|
||||||
|
RedeemCode: redeemCodeRepository,
|
||||||
|
UsageLog: usageLogRepository,
|
||||||
|
Setting: settingRepository,
|
||||||
|
UserSubscription: userSubscriptionRepository,
|
||||||
|
}
|
||||||
|
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
|
||||||
|
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||||
|
v := provideCleanup(db, client, services)
|
||||||
|
application := &Application{
|
||||||
|
Server: httpServer,
|
||||||
|
Cleanup: v,
|
||||||
|
}
|
||||||
|
return application, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wire.go:
|
||||||
|
|
||||||
|
type Application struct {
|
||||||
|
Server *http.Server
|
||||||
|
Cleanup func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
|
return service.BuildInfo{
|
||||||
|
Version: buildInfo.Version,
|
||||||
|
BuildType: buildInfo.BuildType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func provideCleanup(
|
||||||
|
db *gorm.DB,
|
||||||
|
rdb *redis.Client,
|
||||||
|
services *service.Services,
|
||||||
|
) func() {
|
||||||
|
return func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cleanupSteps := []struct {
|
||||||
|
name string
|
||||||
|
fn func() error
|
||||||
|
}{
|
||||||
|
{"TokenRefreshService", func() error {
|
||||||
|
services.TokenRefresh.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"PricingService", func() error {
|
||||||
|
services.Pricing.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"EmailQueueService", func() error {
|
||||||
|
services.EmailQueue.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OAuthService", func() error {
|
||||||
|
services.OAuth.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpenAIOAuthService", func() error {
|
||||||
|
services.OpenAIOAuth.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"Redis", func() error {
|
||||||
|
return rdb.Close()
|
||||||
|
}},
|
||||||
|
{"Database", func() error {
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sqlDB.Close()
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, step := range cleanupSteps {
|
||||||
|
if err := step.fn(); err != nil {
|
||||||
|
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
|
||||||
|
default:
|
||||||
|
log.Printf("[Cleanup] All cleanup steps completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
server:
|
|
||||||
host: "0.0.0.0"
|
|
||||||
port: 8080
|
|
||||||
mode: "debug" # debug/release
|
|
||||||
|
|
||||||
database:
|
|
||||||
host: "127.0.0.1"
|
|
||||||
port: 5432
|
|
||||||
user: "postgres"
|
|
||||||
password: "XZeRr7nkjHWhm8fw"
|
|
||||||
dbname: "sub2api"
|
|
||||||
sslmode: "disable"
|
|
||||||
|
|
||||||
redis:
|
|
||||||
host: "127.0.0.1"
|
|
||||||
port: 6379
|
|
||||||
password: ""
|
|
||||||
db: 0
|
|
||||||
|
|
||||||
jwt:
|
|
||||||
secret: "your-secret-key-change-in-production"
|
|
||||||
expire_hour: 24
|
|
||||||
|
|
||||||
default:
|
|
||||||
admin_email: "admin@sub2api.com"
|
|
||||||
admin_password: "admin123"
|
|
||||||
user_concurrency: 5
|
|
||||||
user_balance: 0
|
|
||||||
api_key_prefix: "sk-"
|
|
||||||
rate_multiplier: 1.0
|
|
||||||
|
|
||||||
# Timezone configuration (similar to PHP's date_default_timezone_set)
|
|
||||||
# This affects ALL time operations:
|
|
||||||
# - Database timestamps
|
|
||||||
# - Usage statistics "today" boundary
|
|
||||||
# - Subscription expiry times
|
|
||||||
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
|
|
||||||
timezone: "Asia/Shanghai"
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
module sub2api
|
module github.com/Wei-Shaw/sub2api
|
||||||
|
|
||||||
go 1.24.0
|
go 1.24.0
|
||||||
|
|
||||||
@@ -8,11 +8,15 @@ require (
|
|||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/google/wire v0.7.0
|
||||||
github.com/imroc/req/v3 v3.56.0
|
github.com/imroc/req/v3 v3.56.0
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
github.com/redis/go-redis/v9 v9.3.0
|
github.com/redis/go-redis/v9 v9.3.0
|
||||||
github.com/spf13/viper v1.18.2
|
github.com/spf13/viper v1.18.2
|
||||||
|
github.com/tidwall/gjson v1.18.0
|
||||||
|
github.com/tidwall/sjson v1.2.5
|
||||||
golang.org/x/crypto v0.44.0
|
golang.org/x/crypto v0.44.0
|
||||||
|
golang.org/x/net v0.47.0
|
||||||
golang.org/x/term v0.37.0
|
golang.org/x/term v0.37.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gorm.io/driver/postgres v1.5.4
|
gorm.io/driver/postgres v1.5.4
|
||||||
@@ -33,6 +37,7 @@ require (
|
|||||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/google/go-querystring v1.1.0 // indirect
|
github.com/google/go-querystring v1.1.0 // indirect
|
||||||
|
github.com/google/subcommands v1.2.0 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/icholy/digest v1.1.0 // indirect
|
github.com/icholy/digest v1.1.0 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
@@ -50,6 +55,7 @@ require (
|
|||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
github.com/quic-go/qpack v0.5.1 // indirect
|
github.com/quic-go/qpack v0.5.1 // indirect
|
||||||
github.com/quic-go/quic-go v0.56.0 // indirect
|
github.com/quic-go/quic-go v0.56.0 // indirect
|
||||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||||
@@ -60,15 +66,19 @@ require (
|
|||||||
github.com/spf13/cast v1.6.0 // indirect
|
github.com/spf13/cast v1.6.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.9.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||||
golang.org/x/net v0.47.0 // indirect
|
golang.org/x/mod v0.29.0 // indirect
|
||||||
|
golang.org/x/sync v0.18.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
|
golang.org/x/tools v0.38.0 // indirect
|
||||||
google.golang.org/protobuf v1.31.0 // indirect
|
google.golang.org/protobuf v1.31.0 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -48,8 +48,12 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
|
|||||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||||
|
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||||
|
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
|
||||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||||
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||||
@@ -135,6 +139,15 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
|
|||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||||
@@ -154,8 +167,12 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
|||||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
||||||
|
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||||
|
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||||
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
@@ -166,6 +183,8 @@ golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
|||||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
|
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||||
|
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||||
|
|||||||
@@ -8,15 +8,30 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Pricing PricingConfig `mapstructure:"pricing"`
|
Pricing PricingConfig `mapstructure:"pricing"`
|
||||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRefreshConfig OAuth token自动刷新配置
|
||||||
|
type TokenRefreshConfig struct {
|
||||||
|
// 是否启用自动刷新
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
// 检查间隔(分钟)
|
||||||
|
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
|
||||||
|
// 提前刷新时间(小时),在token过期前多久开始刷新
|
||||||
|
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
|
||||||
|
// 最大重试次数
|
||||||
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
|
// 重试退避基础时间(秒)
|
||||||
|
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PricingConfig struct {
|
type PricingConfig struct {
|
||||||
@@ -37,7 +52,7 @@ type PricingConfig struct {
|
|||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Host string `mapstructure:"host"`
|
Host string `mapstructure:"host"`
|
||||||
Port int `mapstructure:"port"`
|
Port int `mapstructure:"port"`
|
||||||
Mode string `mapstructure:"mode"` // debug/release
|
Mode string `mapstructure:"mode"` // debug/release
|
||||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||||
}
|
}
|
||||||
@@ -148,7 +163,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("server.port", 8080)
|
viper.SetDefault("server.port", 8080)
|
||||||
viper.SetDefault("server.mode", "debug")
|
viper.SetDefault("server.mode", "debug")
|
||||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||||
|
|
||||||
// Database
|
// Database
|
||||||
viper.SetDefault("database.host", "localhost")
|
viper.SetDefault("database.host", "localhost")
|
||||||
@@ -192,6 +207,13 @@ func setDefaults() {
|
|||||||
|
|
||||||
// Gateway
|
// Gateway
|
||||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||||
|
|
||||||
|
// TokenRefresh
|
||||||
|
viper.SetDefault("token_refresh.enabled", true)
|
||||||
|
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||||
|
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||||
|
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||||
|
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
@@ -203,3 +225,29 @@ func (c *Config) Validate() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetServerAddress returns the server address (host:port) from config file or environment variable.
|
||||||
|
// This is a lightweight function that can be used before full config validation,
|
||||||
|
// such as during setup wizard startup.
|
||||||
|
// Priority: config.yaml > environment variables > defaults
|
||||||
|
func GetServerAddress() string {
|
||||||
|
v := viper.New()
|
||||||
|
v.SetConfigName("config")
|
||||||
|
v.SetConfigType("yaml")
|
||||||
|
v.AddConfigPath(".")
|
||||||
|
v.AddConfigPath("./config")
|
||||||
|
v.AddConfigPath("/etc/sub2api")
|
||||||
|
|
||||||
|
// Support SERVER_HOST and SERVER_PORT environment variables
|
||||||
|
v.AutomaticEnv()
|
||||||
|
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||||
|
v.SetDefault("server.host", "0.0.0.0")
|
||||||
|
v.SetDefault("server.port", 8080)
|
||||||
|
|
||||||
|
// Try to read config file (ignore errors if not found)
|
||||||
|
_ = v.ReadInConfig()
|
||||||
|
|
||||||
|
host := v.GetString("server.host")
|
||||||
|
port := v.GetInt("server.port")
|
||||||
|
return fmt.Sprintf("%s:%d", host, port)
|
||||||
|
}
|
||||||
|
|||||||
13
backend/internal/config/wire.go
Normal file
13
backend/internal/config/wire.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "github.com/google/wire"
|
||||||
|
|
||||||
|
// ProviderSet 提供配置层的依赖
|
||||||
|
var ProviderSet = wire.NewSet(
|
||||||
|
ProvideConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProvideConfig 提供应用配置
|
||||||
|
func ProvideConfig() (*Config, error) {
|
||||||
|
return Load()
|
||||||
|
}
|
||||||
@@ -3,8 +3,12 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -12,14 +16,12 @@ import (
|
|||||||
// OAuthHandler handles OAuth-related operations for accounts
|
// OAuthHandler handles OAuth-related operations for accounts
|
||||||
type OAuthHandler struct {
|
type OAuthHandler struct {
|
||||||
oauthService *service.OAuthService
|
oauthService *service.OAuthService
|
||||||
adminService service.AdminService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthHandler creates a new OAuth handler
|
// NewOAuthHandler creates a new OAuth handler
|
||||||
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
|
func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
|
||||||
return &OAuthHandler{
|
return &OAuthHandler{
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
adminService: adminService,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,47 +29,81 @@ func NewOAuthHandler(oauthService *service.OAuthService, adminService service.Ad
|
|||||||
type AccountHandler struct {
|
type AccountHandler struct {
|
||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
oauthService *service.OAuthService
|
oauthService *service.OAuthService
|
||||||
|
openaiOAuthService *service.OpenAIOAuthService
|
||||||
rateLimitService *service.RateLimitService
|
rateLimitService *service.RateLimitService
|
||||||
accountUsageService *service.AccountUsageService
|
accountUsageService *service.AccountUsageService
|
||||||
accountTestService *service.AccountTestService
|
accountTestService *service.AccountTestService
|
||||||
|
concurrencyService *service.ConcurrencyService
|
||||||
|
crsSyncService *service.CRSSyncService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountHandler creates a new admin account handler
|
// NewAccountHandler creates a new admin account handler
|
||||||
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
|
func NewAccountHandler(
|
||||||
|
adminService service.AdminService,
|
||||||
|
oauthService *service.OAuthService,
|
||||||
|
openaiOAuthService *service.OpenAIOAuthService,
|
||||||
|
rateLimitService *service.RateLimitService,
|
||||||
|
accountUsageService *service.AccountUsageService,
|
||||||
|
accountTestService *service.AccountTestService,
|
||||||
|
concurrencyService *service.ConcurrencyService,
|
||||||
|
crsSyncService *service.CRSSyncService,
|
||||||
|
) *AccountHandler {
|
||||||
return &AccountHandler{
|
return &AccountHandler{
|
||||||
adminService: adminService,
|
adminService: adminService,
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
|
openaiOAuthService: openaiOAuthService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
accountUsageService: accountUsageService,
|
accountUsageService: accountUsageService,
|
||||||
accountTestService: accountTestService,
|
accountTestService: accountTestService,
|
||||||
|
concurrencyService: concurrencyService,
|
||||||
|
crsSyncService: crsSyncService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountRequest represents create account request
|
// CreateAccountRequest represents create account request
|
||||||
type CreateAccountRequest struct {
|
type CreateAccountRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Platform string `json:"platform" binding:"required"`
|
Platform string `json:"platform" binding:"required"`
|
||||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
|
||||||
Credentials map[string]interface{} `json:"credentials" binding:"required"`
|
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||||
Extra map[string]interface{} `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
Priority int `json:"priority"`
|
Priority int `json:"priority"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountRequest represents update account request
|
// UpdateAccountRequest represents update account request
|
||||||
// 使用指针类型来区分"未提供"和"设置为0"
|
// 使用指针类型来区分"未提供"和"设置为0"
|
||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
|
||||||
Credentials map[string]interface{} `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]interface{} `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
Priority *int `json:"priority"`
|
Priority *int `json:"priority"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
|
||||||
|
type BulkUpdateAccountsRequest struct {
|
||||||
|
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
Concurrency *int `json:"concurrency"`
|
||||||
|
Priority *int `json:"priority"`
|
||||||
|
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||||
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
|
Credentials map[string]any `json:"credentials"`
|
||||||
|
Extra map[string]any `json:"extra"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||||
|
type AccountWithConcurrency struct {
|
||||||
|
*model.Account
|
||||||
|
CurrentConcurrency int `json:"current_concurrency"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all accounts with pagination
|
// List handles listing all accounts with pagination
|
||||||
@@ -85,7 +121,28 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Paginated(c, accounts, total, page, pageSize)
|
// Get current concurrency counts for all accounts
|
||||||
|
accountIDs := make([]int64, len(accounts))
|
||||||
|
for i, acc := range accounts {
|
||||||
|
accountIDs[i] = acc.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
|
||||||
|
if err != nil {
|
||||||
|
// Log error but don't fail the request, just use 0 for all
|
||||||
|
concurrencyCounts = make(map[int64]int)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build response with concurrency info
|
||||||
|
result := make([]AccountWithConcurrency, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
result[i] = AccountWithConcurrency{
|
||||||
|
Account: &accounts[i],
|
||||||
|
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Paginated(c, result, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID handles getting an account by ID
|
// GetByID handles getting an account by ID
|
||||||
@@ -186,6 +243,18 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"message": "Account deleted successfully"})
|
response.Success(c, gin.H{"message": "Account deleted successfully"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAccountRequest represents the request body for testing an account
|
||||||
|
type TestAccountRequest struct {
|
||||||
|
ModelID string `json:"model_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SyncFromCRSRequest struct {
|
||||||
|
BaseURL string `json:"base_url" binding:"required"`
|
||||||
|
Username string `json:"username" binding:"required"`
|
||||||
|
Password string `json:"password" binding:"required"`
|
||||||
|
SyncProxies *bool `json:"sync_proxies"`
|
||||||
|
}
|
||||||
|
|
||||||
// Test handles testing account connectivity with SSE streaming
|
// Test handles testing account connectivity with SSE streaming
|
||||||
// POST /api/v1/admin/accounts/:id/test
|
// POST /api/v1/admin/accounts/:id/test
|
||||||
func (h *AccountHandler) Test(c *gin.Context) {
|
func (h *AccountHandler) Test(c *gin.Context) {
|
||||||
@@ -195,13 +264,46 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var req TestAccountRequest
|
||||||
|
// Allow empty body, model_id is optional
|
||||||
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
|
||||||
// Use AccountTestService to test the account with SSE streaming
|
// Use AccountTestService to test the account with SSE streaming
|
||||||
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil {
|
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
||||||
// Error already sent via SSE, just log
|
// Error already sent via SSE, just log
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||||
|
// POST /api/v1/admin/accounts/sync/crs
|
||||||
|
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
||||||
|
var req SyncFromCRSRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to syncing proxies (can be disabled by explicitly setting false)
|
||||||
|
syncProxies := true
|
||||||
|
if req.SyncProxies != nil {
|
||||||
|
syncProxies = *req.SyncProxies
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
||||||
|
BaseURL: req.BaseURL,
|
||||||
|
Username: req.Username,
|
||||||
|
Password: req.Password,
|
||||||
|
SyncProxies: syncProxies,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Sync failed: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
// Refresh handles refreshing account credentials
|
// Refresh handles refreshing account credentials
|
||||||
// POST /api/v1/admin/accounts/:id/refresh
|
// POST /api/v1/admin/accounts/:id/refresh
|
||||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||||
@@ -224,21 +326,46 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use OAuth service to refresh token
|
var newCredentials map[string]any
|
||||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
|
||||||
if err != nil {
|
|
||||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update account credentials
|
if account.IsOpenAI() {
|
||||||
newCredentials := map[string]interface{}{
|
// Use OpenAI OAuth service to refresh token
|
||||||
"access_token": tokenInfo.AccessToken,
|
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||||
"token_type": tokenInfo.TokenType,
|
if err != nil {
|
||||||
"expires_in": tokenInfo.ExpiresIn,
|
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||||
"expires_at": tokenInfo.ExpiresAt,
|
return
|
||||||
"refresh_token": tokenInfo.RefreshToken,
|
}
|
||||||
"scope": tokenInfo.Scope,
|
|
||||||
|
// Build new credentials from token info
|
||||||
|
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
|
// Preserve non-token settings from existing credentials
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
if _, exists := newCredentials[k]; !exists {
|
||||||
|
newCredentials[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Use Anthropic/Claude OAuth service to refresh token
|
||||||
|
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||||
|
newCredentials = make(map[string]any)
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
newCredentials[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update token-related fields
|
||||||
|
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||||
|
newCredentials["token_type"] = tokenInfo.TokenType
|
||||||
|
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||||
|
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||||
|
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||||
|
newCredentials["scope"] = tokenInfo.Scope
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||||
@@ -261,15 +388,26 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return mock data for now
|
// Parse days parameter (default 30)
|
||||||
_ = accountID
|
days := 30
|
||||||
response.Success(c, gin.H{
|
if daysStr := c.Query("days"); daysStr != "" {
|
||||||
"total_requests": 0,
|
if d, err := strconv.Atoi(daysStr); err == nil && d > 0 && d <= 90 {
|
||||||
"successful_requests": 0,
|
days = d
|
||||||
"failed_requests": 0,
|
}
|
||||||
"total_tokens": 0,
|
}
|
||||||
"average_response_time": 0,
|
|
||||||
})
|
// Calculate time range
|
||||||
|
now := timezone.Now()
|
||||||
|
endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||||
|
startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1))
|
||||||
|
|
||||||
|
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to get account stats: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearError handles clearing account error
|
// ClearError handles clearing account error
|
||||||
@@ -309,6 +447,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchUpdateCredentialsRequest represents batch credentials update request
|
||||||
|
type BatchUpdateCredentialsRequest struct {
|
||||||
|
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||||
|
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
|
||||||
|
Value any `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchUpdateCredentials handles batch updating credentials fields
|
||||||
|
// POST /api/v1/admin/accounts/batch-update-credentials
|
||||||
|
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
||||||
|
var req BatchUpdateCredentialsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate value type based on field
|
||||||
|
if req.Field == "intercept_warmup_requests" {
|
||||||
|
// Must be boolean
|
||||||
|
if _, ok := req.Value.(bool); !ok {
|
||||||
|
response.BadRequest(c, "intercept_warmup_requests must be boolean")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// account_uuid and org_uuid can be string or null
|
||||||
|
if req.Value != nil {
|
||||||
|
if _, ok := req.Value.(string); !ok {
|
||||||
|
response.BadRequest(c, req.Field+" must be string or null")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
success := 0
|
||||||
|
failed := 0
|
||||||
|
results := []gin.H{}
|
||||||
|
|
||||||
|
for _, accountID := range req.AccountIDs {
|
||||||
|
// Get account
|
||||||
|
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
failed++
|
||||||
|
results = append(results, gin.H{
|
||||||
|
"account_id": accountID,
|
||||||
|
"success": false,
|
||||||
|
"error": "Account not found",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update credentials field
|
||||||
|
if account.Credentials == nil {
|
||||||
|
account.Credentials = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Credentials[req.Field] = req.Value
|
||||||
|
|
||||||
|
// Update account
|
||||||
|
updateInput := &service.UpdateAccountInput{
|
||||||
|
Credentials: account.Credentials,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
||||||
|
if err != nil {
|
||||||
|
failed++
|
||||||
|
results = append(results, gin.H{
|
||||||
|
"account_id": accountID,
|
||||||
|
"success": false,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
success++
|
||||||
|
results = append(results, gin.H{
|
||||||
|
"account_id": accountID,
|
||||||
|
"success": true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"success": success,
|
||||||
|
"failed": failed,
|
||||||
|
"results": results,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
|
||||||
|
// POST /api/v1/admin/accounts/bulk-update
|
||||||
|
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||||
|
var req BulkUpdateAccountsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hasUpdates := req.Name != "" ||
|
||||||
|
req.ProxyID != nil ||
|
||||||
|
req.Concurrency != nil ||
|
||||||
|
req.Priority != nil ||
|
||||||
|
req.Status != "" ||
|
||||||
|
req.GroupIDs != nil ||
|
||||||
|
len(req.Credentials) > 0 ||
|
||||||
|
len(req.Extra) > 0
|
||||||
|
|
||||||
|
if !hasUpdates {
|
||||||
|
response.BadRequest(c, "No updates provided")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
|
||||||
|
AccountIDs: req.AccountIDs,
|
||||||
|
Name: req.Name,
|
||||||
|
ProxyID: req.ProxyID,
|
||||||
|
Concurrency: req.Concurrency,
|
||||||
|
Priority: req.Priority,
|
||||||
|
Status: req.Status,
|
||||||
|
GroupIDs: req.GroupIDs,
|
||||||
|
Credentials: req.Credentials,
|
||||||
|
Extra: req.Extra,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
// ========== OAuth Handlers ==========
|
// ========== OAuth Handlers ==========
|
||||||
|
|
||||||
// GenerateAuthURLRequest represents the request for generating auth URL
|
// GenerateAuthURLRequest represents the request for generating auth URL
|
||||||
@@ -535,3 +803,98 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, account)
|
response.Success(c, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAvailableModels handles getting available models for an account
|
||||||
|
// GET /api/v1/admin/accounts/:id/models
|
||||||
|
func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||||
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid account ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "Account not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle OpenAI accounts
|
||||||
|
if account.IsOpenAI() {
|
||||||
|
// For OAuth accounts: return default OpenAI models
|
||||||
|
if account.IsOAuth() {
|
||||||
|
response.Success(c, openai.DefaultModels)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For API Key accounts: check model_mapping
|
||||||
|
mapping := account.GetModelMapping()
|
||||||
|
if len(mapping) == 0 {
|
||||||
|
response.Success(c, openai.DefaultModels)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return mapped models
|
||||||
|
var models []openai.Model
|
||||||
|
for requestedModel := range mapping {
|
||||||
|
var found bool
|
||||||
|
for _, dm := range openai.DefaultModels {
|
||||||
|
if dm.ID == requestedModel {
|
||||||
|
models = append(models, dm)
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
models = append(models, openai.Model{
|
||||||
|
ID: requestedModel,
|
||||||
|
Object: "model",
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: requestedModel,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response.Success(c, models)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Claude/Anthropic accounts
|
||||||
|
// For OAuth and Setup-Token accounts: return default models
|
||||||
|
if account.IsOAuth() {
|
||||||
|
response.Success(c, claude.DefaultModels)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For API Key accounts: return models based on model_mapping
|
||||||
|
mapping := account.GetModelMapping()
|
||||||
|
if len(mapping) == 0 {
|
||||||
|
// No mapping configured, return default models
|
||||||
|
response.Success(c, claude.DefaultModels)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return mapped models (keys of the mapping are the available model IDs)
|
||||||
|
var models []claude.Model
|
||||||
|
for requestedModel := range mapping {
|
||||||
|
// Try to find display info from default models
|
||||||
|
var found bool
|
||||||
|
for _, dm := range claude.DefaultModels {
|
||||||
|
if dm.ID == requestedModel {
|
||||||
|
models = append(models, dm)
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If not found in defaults, create a basic entry
|
||||||
|
if !found {
|
||||||
|
models = append(models, claude.Model{
|
||||||
|
ID: requestedModel,
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: requestedModel,
|
||||||
|
CreatedAt: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, models)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sub2api/internal/pkg/response"
|
|
||||||
"sub2api/internal/pkg/timezone"
|
|
||||||
"sub2api/internal/repository"
|
|
||||||
"sub2api/internal/service"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -13,17 +12,15 @@ import (
|
|||||||
|
|
||||||
// DashboardHandler handles admin dashboard statistics
|
// DashboardHandler handles admin dashboard statistics
|
||||||
type DashboardHandler struct {
|
type DashboardHandler struct {
|
||||||
adminService service.AdminService
|
dashboardService *service.DashboardService
|
||||||
usageRepo *repository.UsageLogRepository
|
startTime time.Time // Server start time for uptime calculation
|
||||||
startTime time.Time // Server start time for uptime calculation
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDashboardHandler creates a new admin dashboard handler
|
// NewDashboardHandler creates a new admin dashboard handler
|
||||||
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
|
||||||
return &DashboardHandler{
|
return &DashboardHandler{
|
||||||
adminService: adminService,
|
dashboardService: dashboardService,
|
||||||
usageRepo: usageRepo,
|
startTime: time.Now(),
|
||||||
startTime: time.Now(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,7 +58,7 @@ func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
|
|||||||
// GetStats handles getting dashboard statistics
|
// GetStats handles getting dashboard statistics
|
||||||
// GET /api/v1/admin/dashboard/stats
|
// GET /api/v1/admin/dashboard/stats
|
||||||
func (h *DashboardHandler) GetStats(c *gin.Context) {
|
func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||||
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context())
|
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||||
return
|
return
|
||||||
@@ -110,6 +107,10 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
|
|||||||
// 系统运行统计
|
// 系统运行统计
|
||||||
"average_duration_ms": stats.AverageDurationMs,
|
"average_duration_ms": stats.AverageDurationMs,
|
||||||
"uptime": uptime,
|
"uptime": uptime,
|
||||||
|
|
||||||
|
// 性能指标
|
||||||
|
"rpm": stats.Rpm,
|
||||||
|
"tpm": stats.Tpm,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,12 +128,25 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
|||||||
|
|
||||||
// GetUsageTrend handles getting usage trend data
|
// GetUsageTrend handles getting usage trend data
|
||||||
// GET /api/v1/admin/dashboard/trend
|
// GET /api/v1/admin/dashboard/trend
|
||||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour)
|
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
|
||||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||||
startTime, endTime := parseTimeRange(c)
|
startTime, endTime := parseTimeRange(c)
|
||||||
granularity := c.DefaultQuery("granularity", "day")
|
granularity := c.DefaultQuery("granularity", "day")
|
||||||
|
|
||||||
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity)
|
// Parse optional filter params
|
||||||
|
var userID, apiKeyID int64
|
||||||
|
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||||
|
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||||
|
userID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||||
|
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||||
|
apiKeyID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
response.Error(c, 500, "Failed to get usage trend")
|
||||||
return
|
return
|
||||||
@@ -148,11 +162,24 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
|
|
||||||
// GetModelStats handles getting model usage statistics
|
// GetModelStats handles getting model usage statistics
|
||||||
// GET /api/v1/admin/dashboard/models
|
// GET /api/v1/admin/dashboard/models
|
||||||
// Query params: start_date, end_date (YYYY-MM-DD)
|
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
|
||||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||||
startTime, endTime := parseTimeRange(c)
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
|
||||||
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime)
|
// Parse optional filter params
|
||||||
|
var userID, apiKeyID int64
|
||||||
|
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||||
|
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||||
|
userID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||||
|
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||||
|
apiKeyID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
@@ -177,7 +204,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
|||||||
limit = 5
|
limit = 5
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get API key usage trend")
|
response.Error(c, 500, "Failed to get API key usage trend")
|
||||||
return
|
return
|
||||||
@@ -203,7 +230,7 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
|||||||
limit = 12
|
limit = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage trend")
|
response.Error(c, 500, "Failed to get user usage trend")
|
||||||
return
|
return
|
||||||
@@ -232,11 +259,11 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(req.UserIDs) == 0 {
|
if len(req.UserIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage stats")
|
response.Error(c, 500, "Failed to get user usage stats")
|
||||||
return
|
return
|
||||||
@@ -260,11 +287,11 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(req.ApiKeyIDs) == 0 {
|
if len(req.ApiKeyIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get API key usage stats")
|
response.Error(c, 500, "Failed to get API key usage stats")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|||||||
228
backend/internal/handler/admin/openai_oauth_handler.go
Normal file
228
backend/internal/handler/admin/openai_oauth_handler.go
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
|
||||||
|
type OpenAIOAuthHandler struct {
|
||||||
|
openaiOAuthService *service.OpenAIOAuthService
|
||||||
|
adminService service.AdminService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||||
|
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||||
|
return &OpenAIOAuthHandler{
|
||||||
|
openaiOAuthService: openaiOAuthService,
|
||||||
|
adminService: adminService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
|
||||||
|
type OpenAIGenerateAuthURLRequest struct {
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAuthURL generates OpenAI OAuth authorization URL
|
||||||
|
// POST /api/v1/admin/openai/generate-auth-url
|
||||||
|
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||||
|
var req OpenAIGenerateAuthURLRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
// Allow empty body
|
||||||
|
req = OpenAIGenerateAuthURLRequest{}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
|
||||||
|
type OpenAIExchangeCodeRequest struct {
|
||||||
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCode exchanges OpenAI authorization code for tokens
|
||||||
|
// POST /api/v1/admin/openai/exchange-code
|
||||||
|
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||||
|
var req OpenAIExchangeCodeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||||
|
SessionID: req.SessionID,
|
||||||
|
Code: req.Code,
|
||||||
|
RedirectURI: req.RedirectURI,
|
||||||
|
ProxyID: req.ProxyID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, tokenInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||||
|
type OpenAIRefreshTokenRequest struct {
|
||||||
|
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken refreshes an OpenAI OAuth token
|
||||||
|
// POST /api/v1/admin/openai/refresh-token
|
||||||
|
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||||
|
var req OpenAIRefreshTokenRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxyURL string
|
||||||
|
if req.ProxyID != nil {
|
||||||
|
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
|
||||||
|
if err == nil && proxy != nil {
|
||||||
|
proxyURL = proxy.URL()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Failed to refresh token: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, tokenInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||||
|
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||||
|
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||||
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid account ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get account
|
||||||
|
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "Account not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure account is OpenAI platform
|
||||||
|
if !account.IsOpenAI() {
|
||||||
|
response.BadRequest(c, "Account is not an OpenAI account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only refresh OAuth-based accounts
|
||||||
|
if !account.IsOAuth() {
|
||||||
|
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use OpenAI OAuth service to refresh token
|
||||||
|
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build new credentials from token info
|
||||||
|
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
|
// Preserve non-token settings from existing credentials
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
if _, exists := newCredentials[k]; !exists {
|
||||||
|
newCredentials[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||||
|
Credentials: newCredentials,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to update account credentials: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, updatedAccount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||||
|
// POST /api/v1/admin/openai/create-from-oauth
|
||||||
|
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Concurrency int `json:"concurrency"`
|
||||||
|
Priority int `json:"priority"`
|
||||||
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange code for tokens
|
||||||
|
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||||
|
SessionID: req.SessionID,
|
||||||
|
Code: req.Code,
|
||||||
|
RedirectURI: req.RedirectURI,
|
||||||
|
ProxyID: req.ProxyID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build credentials from token info
|
||||||
|
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
|
// Use email as default name if not provided
|
||||||
|
name := req.Name
|
||||||
|
if name == "" && tokenInfo.Email != "" {
|
||||||
|
name = tokenInfo.Email
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = "OpenAI OAuth Account"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create account
|
||||||
|
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||||
|
Name: name,
|
||||||
|
Platform: "openai",
|
||||||
|
Type: "oauth",
|
||||||
|
Credentials: credentials,
|
||||||
|
ProxyID: req.ProxyID,
|
||||||
|
Concurrency: req.Concurrency,
|
||||||
|
Priority: req.Priority,
|
||||||
|
GroupIDs: req.GroupIDs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to create account: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, account)
|
||||||
|
}
|
||||||
@@ -2,9 +2,10 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -112,12 +113,12 @@ func (h *ProxyHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||||
Name: req.Name,
|
Name: strings.TrimSpace(req.Name),
|
||||||
Protocol: req.Protocol,
|
Protocol: strings.TrimSpace(req.Protocol),
|
||||||
Host: req.Host,
|
Host: strings.TrimSpace(req.Host),
|
||||||
Port: req.Port,
|
Port: req.Port,
|
||||||
Username: req.Username,
|
Username: strings.TrimSpace(req.Username),
|
||||||
Password: req.Password,
|
Password: strings.TrimSpace(req.Password),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.BadRequest(c, "Failed to create proxy: "+err.Error())
|
response.BadRequest(c, "Failed to create proxy: "+err.Error())
|
||||||
@@ -143,13 +144,13 @@ func (h *ProxyHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
|
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
|
||||||
Name: req.Name,
|
Name: strings.TrimSpace(req.Name),
|
||||||
Protocol: req.Protocol,
|
Protocol: strings.TrimSpace(req.Protocol),
|
||||||
Host: req.Host,
|
Host: strings.TrimSpace(req.Host),
|
||||||
Port: req.Port,
|
Port: req.Port,
|
||||||
Username: req.Username,
|
Username: strings.TrimSpace(req.Username),
|
||||||
Password: req.Password,
|
Password: strings.TrimSpace(req.Password),
|
||||||
Status: req.Status,
|
Status: strings.TrimSpace(req.Status),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to update proxy: "+err.Error())
|
response.InternalError(c, "Failed to update proxy: "+err.Error())
|
||||||
@@ -235,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
|||||||
response.Paginated(c, accounts, total, page, pageSize)
|
response.Paginated(c, accounts, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||||
type BatchCreateProxyItem struct {
|
type BatchCreateProxyItem struct {
|
||||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||||
@@ -263,8 +263,14 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
|
|||||||
skipped := 0
|
skipped := 0
|
||||||
|
|
||||||
for _, item := range req.Proxies {
|
for _, item := range req.Proxies {
|
||||||
|
// Trim all string fields
|
||||||
|
host := strings.TrimSpace(item.Host)
|
||||||
|
protocol := strings.TrimSpace(item.Protocol)
|
||||||
|
username := strings.TrimSpace(item.Username)
|
||||||
|
password := strings.TrimSpace(item.Password)
|
||||||
|
|
||||||
// Check for duplicates (same host, port, username, password)
|
// Check for duplicates (same host, port, username, password)
|
||||||
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password)
|
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
|
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
|
||||||
return
|
return
|
||||||
@@ -278,11 +284,11 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
|
|||||||
// Create proxy with default name
|
// Create proxy with default name
|
||||||
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||||
Name: "default",
|
Name: "default",
|
||||||
Protocol: item.Protocol,
|
Protocol: protocol,
|
||||||
Host: item.Host,
|
Host: host,
|
||||||
Port: item.Port,
|
Port: item.Port,
|
||||||
Username: item.Username,
|
Username: username,
|
||||||
Password: item.Password,
|
Password: password,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If creation fails due to duplicate, count as skipped
|
// If creation fails due to duplicate, count as skipped
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
|||||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||||
// Return mock data for now
|
// Return mock data for now
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"total_codes": 0,
|
"total_codes": 0,
|
||||||
"active_codes": 0,
|
"active_codes": 0,
|
||||||
"used_codes": 0,
|
"used_codes": 0,
|
||||||
"expired_codes": 0,
|
"expired_codes": 0,
|
||||||
"total_value_distributed": 0.0,
|
"total_value_distributed": 0.0,
|
||||||
"by_type": gin.H{
|
"by_type": gin.H{
|
||||||
"balance": 0,
|
"balance": 0,
|
||||||
@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
writer := csv.NewWriter(&buf)
|
writer := csv.NewWriter(&buf)
|
||||||
|
|
||||||
// Write header
|
// Write header
|
||||||
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
|
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Write data rows
|
// Write data rows
|
||||||
for _, code := range codes {
|
for _, code := range codes {
|
||||||
@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
if code.UsedAt != nil {
|
if code.UsedAt != nil {
|
||||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
writer.Write([]string{
|
if err := writer.Write([]string{
|
||||||
fmt.Sprintf("%d", code.ID),
|
fmt.Sprintf("%d", code.ID),
|
||||||
code.Code,
|
code.Code,
|
||||||
code.Type,
|
code.Type,
|
||||||
@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
usedBy,
|
usedBy,
|
||||||
usedAt,
|
usedAt,
|
||||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
})
|
}); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Flush()
|
writer.Flush()
|
||||||
|
if err := writer.Error(); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.Header("Content-Type", "text/csv")
|
c.Header("Content-Type", "text/csv")
|
||||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -60,6 +60,7 @@ type UpdateSettingsRequest struct {
|
|||||||
SiteSubtitle string `json:"site_subtitle"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
ApiBaseUrl string `json:"api_base_url"`
|
ApiBaseUrl string `json:"api_base_url"`
|
||||||
ContactInfo string `json:"contact_info"`
|
ContactInfo string `json:"contact_info"`
|
||||||
|
DocUrl string `json:"doc_url"`
|
||||||
|
|
||||||
// 默认配置
|
// 默认配置
|
||||||
DefaultConcurrency int `json:"default_concurrency"`
|
DefaultConcurrency int `json:"default_concurrency"`
|
||||||
@@ -104,6 +105,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
SiteSubtitle: req.SiteSubtitle,
|
SiteSubtitle: req.SiteSubtitle,
|
||||||
ApiBaseUrl: req.ApiBaseUrl,
|
ApiBaseUrl: req.ApiBaseUrl,
|
||||||
ContactInfo: req.ContactInfo,
|
ContactInfo: req.ContactInfo,
|
||||||
|
DocUrl: req.DocUrl,
|
||||||
DefaultConcurrency: req.DefaultConcurrency,
|
DefaultConcurrency: req.DefaultConcurrency,
|
||||||
DefaultBalance: req.DefaultBalance,
|
DefaultBalance: req.DefaultBalance,
|
||||||
}
|
}
|
||||||
@@ -256,3 +258,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAdminApiKey 获取管理员 API Key 状态
|
||||||
|
// GET /api/v1/admin/settings/admin-api-key
|
||||||
|
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||||
|
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"exists": exists,
|
||||||
|
"masked_key": maskedKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
|
||||||
|
// POST /api/v1/admin/settings/admin-api-key/regenerate
|
||||||
|
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||||
|
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"key": key, // 完整 key 只在生成时返回一次
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAdminApiKey 删除管理员 API Key
|
||||||
|
// DELETE /api/v1/admin/settings/admin-api-key
|
||||||
|
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||||
|
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||||
|
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,16 +3,16 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
|
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
|
||||||
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
|
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/pkg/sysutil"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SystemHandler handles system-related operations
|
// SystemHandler handles system-related operations
|
||||||
@@ -18,9 +17,9 @@ type SystemHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewSystemHandler creates a new SystemHandler
|
// NewSystemHandler creates a new SystemHandler
|
||||||
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
|
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||||
return &SystemHandler{
|
return &SystemHandler{
|
||||||
updateSvc: service.NewUpdateService(rdb, version, buildType),
|
updateSvc: updateSvc,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,34 +4,32 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UsageHandler handles admin usage-related requests
|
// UsageHandler handles admin usage-related requests
|
||||||
type UsageHandler struct {
|
type UsageHandler struct {
|
||||||
usageRepo *repository.UsageLogRepository
|
|
||||||
apiKeyRepo *repository.ApiKeyRepository
|
|
||||||
usageService *service.UsageService
|
usageService *service.UsageService
|
||||||
|
apiKeyService *service.ApiKeyService
|
||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUsageHandler creates a new admin usage handler
|
// NewUsageHandler creates a new admin usage handler
|
||||||
func NewUsageHandler(
|
func NewUsageHandler(
|
||||||
usageRepo *repository.UsageLogRepository,
|
|
||||||
apiKeyRepo *repository.ApiKeyRepository,
|
|
||||||
usageService *service.UsageService,
|
usageService *service.UsageService,
|
||||||
|
apiKeyService *service.ApiKeyService,
|
||||||
adminService service.AdminService,
|
adminService service.AdminService,
|
||||||
) *UsageHandler {
|
) *UsageHandler {
|
||||||
return &UsageHandler{
|
return &UsageHandler{
|
||||||
usageRepo: usageRepo,
|
usageService: usageService,
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyService: apiKeyService,
|
||||||
usageService: usageService,
|
adminService: adminService,
|
||||||
adminService: adminService,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,15 +80,15 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
endTime = &t
|
endTime = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
filters := repository.UsageLogFilters{
|
filters := usagestats.UsageLogFilters{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
ApiKeyID: apiKeyID,
|
ApiKeyID: apiKeyID,
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
EndTime: endTime,
|
EndTime: endTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
|
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
||||||
return
|
return
|
||||||
@@ -178,7 +176,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get global stats
|
// Get global stats
|
||||||
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||||
return
|
return
|
||||||
@@ -192,7 +190,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||||
keyword := c.Query("q")
|
keyword := c.Query("q")
|
||||||
if keyword == "" {
|
if keyword == "" {
|
||||||
response.Success(c, []interface{}{})
|
response.Success(c, []any{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,7 +234,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
|||||||
userID = id
|
userID = id
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to search API keys: "+err.Error())
|
response.InternalError(c, "Failed to search API keys: "+err.Error())
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -25,6 +25,9 @@ func NewUserHandler(adminService service.AdminService) *UserHandler {
|
|||||||
type CreateUserRequest struct {
|
type CreateUserRequest struct {
|
||||||
Email string `json:"email" binding:"required,email"`
|
Email string `json:"email" binding:"required,email"`
|
||||||
Password string `json:"password" binding:"required,min=6"`
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Wechat string `json:"wechat"`
|
||||||
|
Notes string `json:"notes"`
|
||||||
Balance float64 `json:"balance"`
|
Balance float64 `json:"balance"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
AllowedGroups []int64 `json:"allowed_groups"`
|
AllowedGroups []int64 `json:"allowed_groups"`
|
||||||
@@ -35,6 +38,9 @@ type CreateUserRequest struct {
|
|||||||
type UpdateUserRequest struct {
|
type UpdateUserRequest struct {
|
||||||
Email string `json:"email" binding:"omitempty,email"`
|
Email string `json:"email" binding:"omitempty,email"`
|
||||||
Password string `json:"password" binding:"omitempty,min=6"`
|
Password string `json:"password" binding:"omitempty,min=6"`
|
||||||
|
Username *string `json:"username"`
|
||||||
|
Wechat *string `json:"wechat"`
|
||||||
|
Notes *string `json:"notes"`
|
||||||
Balance *float64 `json:"balance"`
|
Balance *float64 `json:"balance"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||||
@@ -43,8 +49,9 @@ type UpdateUserRequest struct {
|
|||||||
|
|
||||||
// UpdateBalanceRequest represents balance update request
|
// UpdateBalanceRequest represents balance update request
|
||||||
type UpdateBalanceRequest struct {
|
type UpdateBalanceRequest struct {
|
||||||
Balance float64 `json:"balance" binding:"required"`
|
Balance float64 `json:"balance" binding:"required,gt=0"`
|
||||||
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
|
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
|
||||||
|
Notes string `json:"notes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all users with pagination
|
// List handles listing all users with pagination
|
||||||
@@ -94,6 +101,9 @@ func (h *UserHandler) Create(c *gin.Context) {
|
|||||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
|
Username: req.Username,
|
||||||
|
Wechat: req.Wechat,
|
||||||
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
AllowedGroups: req.AllowedGroups,
|
AllowedGroups: req.AllowedGroups,
|
||||||
@@ -125,6 +135,9 @@ func (h *UserHandler) Update(c *gin.Context) {
|
|||||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
|
Username: req.Username,
|
||||||
|
Wechat: req.Wechat,
|
||||||
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
@@ -171,7 +184,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
|
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to update balance: "+err.Error())
|
response.InternalError(c, "Failed to update balance: "+err.Error())
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
page, pageSize := response.ParsePagination(c)
|
page, pageSize := response.ParsePagination(c)
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
|
|
||||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
|
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,28 +7,24 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// Maximum wait time for concurrency slot
|
|
||||||
maxConcurrencyWait = 60 * time.Second
|
|
||||||
// Ping interval during wait
|
|
||||||
pingInterval = 5 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
// GatewayHandler handles API gateway requests
|
// GatewayHandler handles API gateway requests
|
||||||
type GatewayHandler struct {
|
type GatewayHandler struct {
|
||||||
gatewayService *service.GatewayService
|
gatewayService *service.GatewayService
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
concurrencyService *service.ConcurrencyService
|
|
||||||
billingCacheService *service.BillingCacheService
|
billingCacheService *service.BillingCacheService
|
||||||
|
concurrencyHelper *ConcurrencyHelper
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayHandler creates a new GatewayHandler
|
// NewGatewayHandler creates a new GatewayHandler
|
||||||
@@ -36,8 +32,8 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
|
|||||||
return &GatewayHandler{
|
return &GatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
concurrencyService: concurrencyService,
|
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +83,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// 0. 检查wait队列是否已满
|
// 0. 检查wait队列是否已满
|
||||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||||
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment wait count failed: %v", err)
|
log.Printf("Increment wait count failed: %v", err)
|
||||||
// On error, allow request to proceed
|
// On error, allow request to proceed
|
||||||
@@ -96,10 +92,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 确保在函数退出时减少wait计数
|
// 确保在函数退出时减少wait计数
|
||||||
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
|
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||||
|
|
||||||
// 1. 首先获取用户并发槽位
|
// 1. 首先获取用户并发槽位
|
||||||
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("User concurrency acquire failed: %v", err)
|
log.Printf("User concurrency acquire failed: %v", err)
|
||||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||||
@@ -126,8 +122,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
|
if req.Stream {
|
||||||
|
sendMockWarmupStream(c, req.Model)
|
||||||
|
} else {
|
||||||
|
sendMockWarmupResponse(c, req.Model)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 3. 获取账号并发槽位
|
||||||
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
@@ -161,154 +167,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
|
|
||||||
// For streaming requests, sends ping events during the wait
|
|
||||||
// streamStarted is updated if streaming response has begun
|
|
||||||
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
|
||||||
ctx := c.Request.Context()
|
|
||||||
|
|
||||||
// Try to acquire immediately
|
|
||||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Acquired {
|
|
||||||
return result.ReleaseFunc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Need to wait - handle streaming ping if needed
|
|
||||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
|
||||||
}
|
|
||||||
|
|
||||||
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
|
|
||||||
// For streaming requests, sends ping events during the wait
|
|
||||||
// streamStarted is updated if streaming response has begun
|
|
||||||
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
|
||||||
ctx := c.Request.Context()
|
|
||||||
|
|
||||||
// Try to acquire immediately
|
|
||||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Acquired {
|
|
||||||
return result.ReleaseFunc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Need to wait - handle streaming ping if needed
|
|
||||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
|
||||||
}
|
|
||||||
|
|
||||||
// concurrencyError represents a concurrency limit error with context
|
|
||||||
type concurrencyError struct {
|
|
||||||
SlotType string
|
|
||||||
IsTimeout bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *concurrencyError) Error() string {
|
|
||||||
if e.IsTimeout {
|
|
||||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
|
|
||||||
// Note: For streaming requests, we send ping to keep the connection alive.
|
|
||||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
|
|
||||||
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
|
||||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// For streaming requests, set up SSE headers for ping
|
|
||||||
var flusher http.Flusher
|
|
||||||
if isStream {
|
|
||||||
var ok bool
|
|
||||||
flusher, ok = c.Writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("streaming not supported")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pingTicker := time.NewTicker(pingInterval)
|
|
||||||
defer pingTicker.Stop()
|
|
||||||
|
|
||||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
|
||||||
defer pollTicker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, &concurrencyError{
|
|
||||||
SlotType: slotType,
|
|
||||||
IsTimeout: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
case <-pingTicker.C:
|
|
||||||
// Send ping for streaming requests to keep connection alive
|
|
||||||
if isStream && flusher != nil {
|
|
||||||
// Set headers on first ping (lazy initialization)
|
|
||||||
if !*streamStarted {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("X-Accel-Buffering", "no")
|
|
||||||
*streamStarted = true
|
|
||||||
}
|
|
||||||
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
case <-pollTicker.C:
|
|
||||||
// Try to acquire slot
|
|
||||||
var result *service.AcquireResult
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if slotType == "user" {
|
|
||||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
|
||||||
} else {
|
|
||||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Acquired {
|
|
||||||
return result.ReleaseFunc, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Models handles listing available models
|
// Models handles listing available models
|
||||||
// GET /v1/models
|
// GET /v1/models
|
||||||
|
// Returns different model lists based on the API key's group platform
|
||||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||||
models := []gin.H{
|
apiKey, _ := middleware.GetApiKeyFromContext(c)
|
||||||
{
|
|
||||||
"id": "claude-opus-4-5-20251101",
|
// Return OpenAI models for OpenAI platform groups
|
||||||
"type": "model",
|
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
|
||||||
"display_name": "Claude Opus 4.5",
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"created_at": "2025-11-01T00:00:00Z",
|
"object": "list",
|
||||||
},
|
"data": openai.DefaultModels,
|
||||||
{
|
})
|
||||||
"id": "claude-sonnet-4-5-20250929",
|
return
|
||||||
"type": "model",
|
|
||||||
"display_name": "Claude Sonnet 4.5",
|
|
||||||
"created_at": "2025-09-29T00:00:00Z",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "claude-haiku-4-5-20251001",
|
|
||||||
"type": "model",
|
|
||||||
"display_name": "Claude Haiku 4.5",
|
|
||||||
"created_at": "2025-10-01T00:00:00Z",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default: Claude models
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"data": models,
|
|
||||||
"object": "list",
|
"object": "list",
|
||||||
|
"data": claude.DefaultModels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -423,7 +300,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
|||||||
if ok {
|
if ok {
|
||||||
// Send error event in SSE format
|
// Send error event in SSE format
|
||||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||||
fmt.Fprint(c.Writer, errorEvent)
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -443,3 +322,155 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountTokens handles token counting endpoint
|
||||||
|
// POST /v1/messages/count_tokens
|
||||||
|
// 特点:校验订阅/余额,但不计算并发、不记录使用量
|
||||||
|
func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||||
|
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||||
|
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := middleware.GetUserFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取请求体
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析请求获取模型名
|
||||||
|
var req struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取订阅信息(可能为nil)
|
||||||
|
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
|
// 校验 billing eligibility(订阅/余额)
|
||||||
|
// 【注意】不计算并发,但需要校验订阅/余额
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
|
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算粘性会话 hash
|
||||||
|
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||||
|
|
||||||
|
// 选择支持该模型的账号
|
||||||
|
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||||
|
if err != nil {
|
||||||
|
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转发请求(不记录使用量)
|
||||||
|
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
|
||||||
|
log.Printf("Forward count_tokens request failed: %v", err)
|
||||||
|
// 错误响应已在 ForwardCountTokens 中处理
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
|
||||||
|
func isWarmupRequest(body []byte) bool {
|
||||||
|
// 快速检查:如果body不包含关键字,直接返回false
|
||||||
|
bodyStr := string(body)
|
||||||
|
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析完整请求
|
||||||
|
var req struct {
|
||||||
|
Messages []struct {
|
||||||
|
Content []struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"content"`
|
||||||
|
} `json:"messages"`
|
||||||
|
System []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"system"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 messages 中的标题提示模式
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
for _, content := range msg.Content {
|
||||||
|
if content.Type == "text" {
|
||||||
|
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||||
|
content.Text == "Warmup" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 system 中的标题提取模式
|
||||||
|
for _, system := range req.System {
|
||||||
|
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
|
||||||
|
func sendMockWarmupStream(c *gin.Context, model string) {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
events := []string{
|
||||||
|
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
|
||||||
|
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||||
|
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||||
|
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||||
|
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
|
||||||
|
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
|
||||||
|
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range events {
|
||||||
|
_, _ = c.Writer.WriteString(event + "\n\n")
|
||||||
|
c.Writer.Flush()
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||||
|
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"id": "msg_mock_warmup",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": model,
|
||||||
|
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"usage": gin.H{
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": 2,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
180
backend/internal/handler/gateway_helper.go
Normal file
180
backend/internal/handler/gateway_helper.go
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
|
||||||
|
maxConcurrencyWait = 30 * time.Second
|
||||||
|
// pingInterval is the interval for sending ping events during slot wait
|
||||||
|
pingInterval = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSEPingFormat defines the format of SSE ping events for different platforms
|
||||||
|
type SSEPingFormat string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
|
||||||
|
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
|
||||||
|
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
|
||||||
|
SSEPingFormatNone SSEPingFormat = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConcurrencyError represents a concurrency limit error with context
|
||||||
|
type ConcurrencyError struct {
|
||||||
|
SlotType string
|
||||||
|
IsTimeout bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConcurrencyError) Error() string {
|
||||||
|
if e.IsTimeout {
|
||||||
|
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
|
||||||
|
type ConcurrencyHelper struct {
|
||||||
|
concurrencyService *service.ConcurrencyService
|
||||||
|
pingFormat SSEPingFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConcurrencyHelper creates a new ConcurrencyHelper
|
||||||
|
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
|
||||||
|
return &ConcurrencyHelper{
|
||||||
|
concurrencyService: concurrencyService,
|
||||||
|
pingFormat: pingFormat,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementWaitCount increments the wait count for a user
|
||||||
|
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
|
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecrementWaitCount decrements the wait count for a user
|
||||||
|
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||||
|
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||||
|
// For streaming requests, sends ping events during the wait.
|
||||||
|
// streamStarted is updated if streaming response has begun.
|
||||||
|
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
// Try to acquire immediately
|
||||||
|
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Acquired {
|
||||||
|
return result.ReleaseFunc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Need to wait - handle streaming ping if needed
|
||||||
|
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
|
||||||
|
// For streaming requests, sends ping events during the wait.
|
||||||
|
// streamStarted is updated if streaming response has begun.
|
||||||
|
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
// Try to acquire immediately
|
||||||
|
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Acquired {
|
||||||
|
return result.ReleaseFunc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Need to wait - handle streaming ping if needed
|
||||||
|
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||||
|
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||||
|
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Determine if ping is needed (streaming + ping format defined)
|
||||||
|
needPing := isStream && h.pingFormat != ""
|
||||||
|
|
||||||
|
var flusher http.Flusher
|
||||||
|
if needPing {
|
||||||
|
var ok bool
|
||||||
|
flusher, ok = c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("streaming not supported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only create ping ticker if ping is needed
|
||||||
|
var pingCh <-chan time.Time
|
||||||
|
if needPing {
|
||||||
|
pingTicker := time.NewTicker(pingInterval)
|
||||||
|
defer pingTicker.Stop()
|
||||||
|
pingCh = pingTicker.C
|
||||||
|
}
|
||||||
|
|
||||||
|
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||||
|
defer pollTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, &ConcurrencyError{
|
||||||
|
SlotType: slotType,
|
||||||
|
IsTimeout: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-pingCh:
|
||||||
|
// Send ping to keep connection alive
|
||||||
|
if !*streamStarted {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
*streamStarted = true
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
case <-pollTicker.C:
|
||||||
|
// Try to acquire slot
|
||||||
|
var result *service.AcquireResult
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if slotType == "user" {
|
||||||
|
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||||
|
} else {
|
||||||
|
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Acquired {
|
||||||
|
return result.ReleaseFunc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/handler/admin"
|
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||||
"sub2api/internal/repository"
|
|
||||||
"sub2api/internal/service"
|
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// AdminHandlers contains all admin-related HTTP handlers
|
// AdminHandlers contains all admin-related HTTP handlers
|
||||||
@@ -15,6 +11,7 @@ type AdminHandlers struct {
|
|||||||
Group *admin.GroupHandler
|
Group *admin.GroupHandler
|
||||||
Account *admin.AccountHandler
|
Account *admin.AccountHandler
|
||||||
OAuth *admin.OAuthHandler
|
OAuth *admin.OAuthHandler
|
||||||
|
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||||
Proxy *admin.ProxyHandler
|
Proxy *admin.ProxyHandler
|
||||||
Redeem *admin.RedeemHandler
|
Redeem *admin.RedeemHandler
|
||||||
Setting *admin.SettingHandler
|
Setting *admin.SettingHandler
|
||||||
@@ -25,15 +22,16 @@ type AdminHandlers struct {
|
|||||||
|
|
||||||
// Handlers contains all HTTP handlers
|
// Handlers contains all HTTP handlers
|
||||||
type Handlers struct {
|
type Handlers struct {
|
||||||
Auth *AuthHandler
|
Auth *AuthHandler
|
||||||
User *UserHandler
|
User *UserHandler
|
||||||
APIKey *APIKeyHandler
|
APIKey *APIKeyHandler
|
||||||
Usage *UsageHandler
|
Usage *UsageHandler
|
||||||
Redeem *RedeemHandler
|
Redeem *RedeemHandler
|
||||||
Subscription *SubscriptionHandler
|
Subscription *SubscriptionHandler
|
||||||
Admin *AdminHandlers
|
Admin *AdminHandlers
|
||||||
Gateway *GatewayHandler
|
Gateway *GatewayHandler
|
||||||
Setting *SettingHandler
|
OpenAIGateway *OpenAIGatewayHandler
|
||||||
|
Setting *SettingHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildInfo contains build-time information
|
// BuildInfo contains build-time information
|
||||||
@@ -41,30 +39,3 @@ type BuildInfo struct {
|
|||||||
Version string
|
Version string
|
||||||
BuildType string // "source" for manual builds, "release" for CI builds
|
BuildType string // "source" for manual builds, "release" for CI builds
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandlers creates a new Handlers instance with all handlers initialized
|
|
||||||
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
|
|
||||||
return &Handlers{
|
|
||||||
Auth: NewAuthHandler(services.Auth),
|
|
||||||
User: NewUserHandler(services.User),
|
|
||||||
APIKey: NewAPIKeyHandler(services.ApiKey),
|
|
||||||
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
|
|
||||||
Redeem: NewRedeemHandler(services.Redeem),
|
|
||||||
Subscription: NewSubscriptionHandler(services.Subscription),
|
|
||||||
Admin: &AdminHandlers{
|
|
||||||
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
|
|
||||||
User: admin.NewUserHandler(services.Admin),
|
|
||||||
Group: admin.NewGroupHandler(services.Admin),
|
|
||||||
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
|
|
||||||
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
|
|
||||||
Proxy: admin.NewProxyHandler(services.Admin),
|
|
||||||
Redeem: admin.NewRedeemHandler(services.Admin),
|
|
||||||
Setting: admin.NewSettingHandler(services.Setting, services.Email),
|
|
||||||
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
|
|
||||||
Subscription: admin.NewSubscriptionHandler(services.Subscription),
|
|
||||||
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
|
|
||||||
},
|
|
||||||
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
|
|
||||||
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
209
backend/internal/handler/openai_gateway_handler.go
Normal file
209
backend/internal/handler/openai_gateway_handler.go
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||||
|
type OpenAIGatewayHandler struct {
|
||||||
|
gatewayService *service.OpenAIGatewayService
|
||||||
|
billingCacheService *service.BillingCacheService
|
||||||
|
concurrencyHelper *ConcurrencyHelper
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||||
|
func NewOpenAIGatewayHandler(
|
||||||
|
gatewayService *service.OpenAIGatewayService,
|
||||||
|
concurrencyService *service.ConcurrencyService,
|
||||||
|
billingCacheService *service.BillingCacheService,
|
||||||
|
) *OpenAIGatewayHandler {
|
||||||
|
return &OpenAIGatewayHandler{
|
||||||
|
gatewayService: gatewayService,
|
||||||
|
billingCacheService: billingCacheService,
|
||||||
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Responses handles OpenAI Responses API endpoint
|
||||||
|
// POST /openai/v1/responses
|
||||||
|
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||||
|
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||||
|
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := middleware.GetUserFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read request body
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse request body to map for potential modification
|
||||||
|
var reqBody map[string]any
|
||||||
|
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract model and stream
|
||||||
|
reqModel, _ := reqBody["model"].(string)
|
||||||
|
reqStream, _ := reqBody["stream"].(bool)
|
||||||
|
|
||||||
|
// For non-Codex CLI requests, set default instructions
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
if !openai.IsCodexCLIRequest(userAgent) {
|
||||||
|
reqBody["instructions"] = openai.DefaultInstructions
|
||||||
|
// Re-serialize body
|
||||||
|
body, err = json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track if we've started streaming (for error handling)
|
||||||
|
streamStarted := false
|
||||||
|
|
||||||
|
// Get subscription info (may be nil)
|
||||||
|
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
|
// 0. Check if wait queue is full
|
||||||
|
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||||
|
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Increment wait count failed: %v", err)
|
||||||
|
// On error, allow request to proceed
|
||||||
|
} else if !canWait {
|
||||||
|
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Ensure wait count is decremented when function exits
|
||||||
|
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||||
|
|
||||||
|
// 1. First acquire user concurrency slot
|
||||||
|
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("User concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
defer userReleaseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Re-check billing eligibility after wait
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
|
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate session hash (from header for OpenAI)
|
||||||
|
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||||
|
|
||||||
|
// Select account supporting the requested model
|
||||||
|
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||||
|
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||||
|
|
||||||
|
// 3. Acquire account concurrency slot
|
||||||
|
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
defer accountReleaseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward request
|
||||||
|
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||||
|
if err != nil {
|
||||||
|
// Error response already handled in Forward, just log
|
||||||
|
log.Printf("Forward request failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Async record usage
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
ApiKey: apiKey,
|
||||||
|
User: user,
|
||||||
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||||
|
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||||
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||||
|
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||||
|
if streamStarted {
|
||||||
|
// Stream already started, send error as SSE event then close
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if ok {
|
||||||
|
// Send error event in OpenAI SSE format
|
||||||
|
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||||
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal case: return JSON response with proper status code
|
||||||
|
h.errorResponse(c, status, errType, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorResponse returns OpenAI API format error response
|
||||||
|
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||||
|
c.JSON(status, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -16,15 +16,13 @@ import (
|
|||||||
// UsageHandler handles usage-related requests
|
// UsageHandler handles usage-related requests
|
||||||
type UsageHandler struct {
|
type UsageHandler struct {
|
||||||
usageService *service.UsageService
|
usageService *service.UsageService
|
||||||
usageRepo *repository.UsageLogRepository
|
|
||||||
apiKeyService *service.ApiKeyService
|
apiKeyService *service.ApiKeyService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUsageHandler creates a new UsageHandler
|
// NewUsageHandler creates a new UsageHandler
|
||||||
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
|
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||||
return &UsageHandler{
|
return &UsageHandler{
|
||||||
usageService: usageService,
|
usageService: usageService,
|
||||||
usageRepo: usageRepo,
|
|
||||||
apiKeyService: apiKeyService,
|
apiKeyService: apiKeyService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -68,9 +66,9 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
apiKeyID = id
|
apiKeyID = id
|
||||||
}
|
}
|
||||||
|
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
var records []model.UsageLog
|
var records []model.UsageLog
|
||||||
var result *repository.PaginationResult
|
var result *pagination.PaginationResult
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if apiKeyID > 0 {
|
if apiKeyID > 0 {
|
||||||
@@ -259,7 +257,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
|
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to get dashboard statistics")
|
response.InternalError(c, "Failed to get dashboard statistics")
|
||||||
return
|
return
|
||||||
@@ -286,7 +284,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
|||||||
startTime, endTime := parseUserTimeRange(c)
|
startTime, endTime := parseUserTimeRange(c)
|
||||||
granularity := c.DefaultQuery("granularity", "day")
|
granularity := c.DefaultQuery("granularity", "day")
|
||||||
|
|
||||||
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to get usage trend")
|
response.InternalError(c, "Failed to get usage trend")
|
||||||
return
|
return
|
||||||
@@ -317,7 +315,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
|||||||
|
|
||||||
startTime, endTime := parseUserTimeRange(c)
|
startTime, endTime := parseUserTimeRange(c)
|
||||||
|
|
||||||
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to get model statistics")
|
response.InternalError(c, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
@@ -357,12 +355,12 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(req.ApiKeyIDs) == 0 {
|
if len(req.ApiKeyIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify ownership of all requested API keys
|
// Verify ownership of all requested API keys
|
||||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
|
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to verify API key ownership")
|
response.InternalError(c, "Failed to verify API key ownership")
|
||||||
return
|
return
|
||||||
@@ -382,11 +380,11 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(validApiKeyIDs) == 0 {
|
if len(validApiKeyIDs) == 0 {
|
||||||
response.Success(c, gin.H{"stats": map[string]interface{}{}})
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to get API key usage stats")
|
response.InternalError(c, "Failed to get API key usage stats")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -26,6 +26,12 @@ type ChangePasswordRequest struct {
|
|||||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateProfileRequest represents the update profile request payload
|
||||||
|
type UpdateProfileRequest struct {
|
||||||
|
Username *string `json:"username"`
|
||||||
|
Wechat *string `json:"wechat"`
|
||||||
|
}
|
||||||
|
|
||||||
// GetProfile handles getting user profile
|
// GetProfile handles getting user profile
|
||||||
// GET /api/v1/users/me
|
// GET /api/v1/users/me
|
||||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||||
@@ -47,6 +53,9 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 清空notes字段,普通用户不应看到备注
|
||||||
|
userData.Notes = ""
|
||||||
|
|
||||||
response.Success(c, userData)
|
response.Success(c, userData)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,3 +92,40 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, gin.H{"message": "Password changed successfully"})
|
response.Success(c, gin.H{"message": "Password changed successfully"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateProfile handles updating user profile
|
||||||
|
// PUT /api/v1/users/me
|
||||||
|
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||||
|
userValue, exists := c.Get("user")
|
||||||
|
if !exists {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := userValue.(*model.User)
|
||||||
|
if !ok {
|
||||||
|
response.InternalError(c, "Invalid user context")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateProfileRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
svcReq := service.UpdateProfileRequest{
|
||||||
|
Username: req.Username,
|
||||||
|
Wechat: req.Wechat,
|
||||||
|
}
|
||||||
|
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Failed to update profile: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清空notes字段,普通用户不应看到备注
|
||||||
|
updatedUser.Notes = ""
|
||||||
|
|
||||||
|
response.Success(c, updatedUser)
|
||||||
|
}
|
||||||
|
|||||||
108
backend/internal/handler/wire.go
Normal file
108
backend/internal/handler/wire.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/google/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProvideAdminHandlers creates the AdminHandlers struct
|
||||||
|
func ProvideAdminHandlers(
|
||||||
|
dashboardHandler *admin.DashboardHandler,
|
||||||
|
userHandler *admin.UserHandler,
|
||||||
|
groupHandler *admin.GroupHandler,
|
||||||
|
accountHandler *admin.AccountHandler,
|
||||||
|
oauthHandler *admin.OAuthHandler,
|
||||||
|
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||||
|
proxyHandler *admin.ProxyHandler,
|
||||||
|
redeemHandler *admin.RedeemHandler,
|
||||||
|
settingHandler *admin.SettingHandler,
|
||||||
|
systemHandler *admin.SystemHandler,
|
||||||
|
subscriptionHandler *admin.SubscriptionHandler,
|
||||||
|
usageHandler *admin.UsageHandler,
|
||||||
|
) *AdminHandlers {
|
||||||
|
return &AdminHandlers{
|
||||||
|
Dashboard: dashboardHandler,
|
||||||
|
User: userHandler,
|
||||||
|
Group: groupHandler,
|
||||||
|
Account: accountHandler,
|
||||||
|
OAuth: oauthHandler,
|
||||||
|
OpenAIOAuth: openaiOAuthHandler,
|
||||||
|
Proxy: proxyHandler,
|
||||||
|
Redeem: redeemHandler,
|
||||||
|
Setting: settingHandler,
|
||||||
|
System: systemHandler,
|
||||||
|
Subscription: subscriptionHandler,
|
||||||
|
Usage: usageHandler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||||
|
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||||
|
return admin.NewSystemHandler(updateService)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||||
|
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
|
||||||
|
return NewSettingHandler(settingService, buildInfo.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideHandlers creates the Handlers struct
|
||||||
|
func ProvideHandlers(
|
||||||
|
authHandler *AuthHandler,
|
||||||
|
userHandler *UserHandler,
|
||||||
|
apiKeyHandler *APIKeyHandler,
|
||||||
|
usageHandler *UsageHandler,
|
||||||
|
redeemHandler *RedeemHandler,
|
||||||
|
subscriptionHandler *SubscriptionHandler,
|
||||||
|
adminHandlers *AdminHandlers,
|
||||||
|
gatewayHandler *GatewayHandler,
|
||||||
|
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||||
|
settingHandler *SettingHandler,
|
||||||
|
) *Handlers {
|
||||||
|
return &Handlers{
|
||||||
|
Auth: authHandler,
|
||||||
|
User: userHandler,
|
||||||
|
APIKey: apiKeyHandler,
|
||||||
|
Usage: usageHandler,
|
||||||
|
Redeem: redeemHandler,
|
||||||
|
Subscription: subscriptionHandler,
|
||||||
|
Admin: adminHandlers,
|
||||||
|
Gateway: gatewayHandler,
|
||||||
|
OpenAIGateway: openaiGatewayHandler,
|
||||||
|
Setting: settingHandler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderSet is the Wire provider set for all handlers
|
||||||
|
var ProviderSet = wire.NewSet(
|
||||||
|
// Top-level handlers
|
||||||
|
NewAuthHandler,
|
||||||
|
NewUserHandler,
|
||||||
|
NewAPIKeyHandler,
|
||||||
|
NewUsageHandler,
|
||||||
|
NewRedeemHandler,
|
||||||
|
NewSubscriptionHandler,
|
||||||
|
NewGatewayHandler,
|
||||||
|
NewOpenAIGatewayHandler,
|
||||||
|
ProvideSettingHandler,
|
||||||
|
|
||||||
|
// Admin handlers
|
||||||
|
admin.NewDashboardHandler,
|
||||||
|
admin.NewUserHandler,
|
||||||
|
admin.NewGroupHandler,
|
||||||
|
admin.NewAccountHandler,
|
||||||
|
admin.NewOAuthHandler,
|
||||||
|
admin.NewOpenAIOAuthHandler,
|
||||||
|
admin.NewProxyHandler,
|
||||||
|
admin.NewRedeemHandler,
|
||||||
|
admin.NewSettingHandler,
|
||||||
|
ProvideSystemHandler,
|
||||||
|
admin.NewSubscriptionHandler,
|
||||||
|
admin.NewUsageHandler,
|
||||||
|
|
||||||
|
// AdminHandlers and Handlers constructors
|
||||||
|
ProvideAdminHandlers,
|
||||||
|
ProvideHandlers,
|
||||||
|
)
|
||||||
38
backend/internal/infrastructure/database.go
Normal file
38
backend/internal/infrastructure/database.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package infrastructure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InitDB 初始化数据库连接
|
||||||
|
func InitDB(cfg *config.Config) (*gorm.DB, error) {
|
||||||
|
// 初始化时区(在数据库连接之前,确保时区设置正确)
|
||||||
|
if err := timezone.Init(cfg.Timezone); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
gormConfig := &gorm.Config{}
|
||||||
|
if cfg.Server.Mode == "debug" {
|
||||||
|
gormConfig.Logger = logger.Default.LogMode(logger.Info)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用带时区的 DSN 连接数据库
|
||||||
|
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动迁移(始终执行,确保数据库结构与代码同步)
|
||||||
|
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
|
||||||
|
if err := model.AutoMigrate(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
16
backend/internal/infrastructure/redis.go
Normal file
16
backend/internal/infrastructure/redis.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package infrastructure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InitRedis 初始化 Redis 客户端
|
||||||
|
func InitRedis(cfg *config.Config) *redis.Client {
|
||||||
|
return redis.NewClient(&redis.Options{
|
||||||
|
Addr: cfg.Redis.Address(),
|
||||||
|
Password: cfg.Redis.Password,
|
||||||
|
DB: cfg.Redis.DB,
|
||||||
|
})
|
||||||
|
}
|
||||||
25
backend/internal/infrastructure/wire.go
Normal file
25
backend/internal/infrastructure/wire.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package infrastructure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
|
||||||
|
"github.com/google/wire"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderSet 提供基础设施层的依赖
|
||||||
|
var ProviderSet = wire.NewSet(
|
||||||
|
ProvideDB,
|
||||||
|
ProvideRedis,
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProvideDB 提供数据库连接
|
||||||
|
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
|
||||||
|
return InitDB(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideRedis 提供 Redis 客户端
|
||||||
|
func ProvideRedis(cfg *config.Config) *redis.Client {
|
||||||
|
return InitRedis(cfg)
|
||||||
|
}
|
||||||
130
backend/internal/middleware/admin_auth.go
Normal file
130
backend/internal/middleware/admin_auth.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AdminAuth 管理员认证中间件
|
||||||
|
// 支持两种认证方式(通过不同的 header 区分):
|
||||||
|
// 1. Admin API Key: x-api-key: <admin-api-key>
|
||||||
|
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
|
||||||
|
func AdminAuth(
|
||||||
|
authService *service.AuthService,
|
||||||
|
userRepo interface {
|
||||||
|
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||||
|
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||||
|
},
|
||||||
|
settingService *service.SettingService,
|
||||||
|
) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// 检查 x-api-key header(Admin API Key 认证)
|
||||||
|
apiKey := c.GetHeader("x-api-key")
|
||||||
|
if apiKey != "" {
|
||||||
|
if !validateAdminApiKey(c, apiKey, settingService, userRepo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 Authorization header(JWT 认证)
|
||||||
|
authHeader := c.GetHeader("Authorization")
|
||||||
|
if authHeader != "" {
|
||||||
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
|
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||||
|
if !validateJWTForAdmin(c, parts[1], authService, userRepo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 无有效认证信息
|
||||||
|
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateAdminApiKey 验证管理员 API Key
|
||||||
|
func validateAdminApiKey(
|
||||||
|
c *gin.Context,
|
||||||
|
key string,
|
||||||
|
settingService *service.SettingService,
|
||||||
|
userRepo interface {
|
||||||
|
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||||
|
},
|
||||||
|
) bool {
|
||||||
|
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
|
||||||
|
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
|
||||||
|
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取真实的管理员用户
|
||||||
|
admin, err := userRepo.GetFirstAdmin(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(string(ContextKeyUser), admin)
|
||||||
|
c.Set("auth_method", "admin_api_key")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateJWTForAdmin 验证 JWT 并检查管理员权限
|
||||||
|
func validateJWTForAdmin(
|
||||||
|
c *gin.Context,
|
||||||
|
token string,
|
||||||
|
authService *service.AuthService,
|
||||||
|
userRepo interface {
|
||||||
|
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||||
|
},
|
||||||
|
) bool {
|
||||||
|
// 验证 JWT token
|
||||||
|
claims, err := authService.ValidateToken(token)
|
||||||
|
if err != nil {
|
||||||
|
if err == service.ErrTokenExpired {
|
||||||
|
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从数据库获取用户
|
||||||
|
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
|
||||||
|
if err != nil {
|
||||||
|
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查用户状态
|
||||||
|
if !user.IsActive() {
|
||||||
|
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查管理员权限
|
||||||
|
if user.Role != model.RoleAdmin {
|
||||||
|
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(string(ContextKeyUser), user)
|
||||||
|
c.Set("auth_method", "jwt")
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
"sub2api/internal/model"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"strings"
|
"strings"
|
||||||
"sub2api/internal/model"
|
|
||||||
"sub2api/internal/service"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// JSONB 用于存储JSONB数据
|
// JSONB 用于存储JSONB数据
|
||||||
type JSONB map[string]interface{}
|
type JSONB map[string]any
|
||||||
|
|
||||||
func (j JSONB) Value() (driver.Value, error) {
|
func (j JSONB) Value() (driver.Value, error) {
|
||||||
if j == nil {
|
if j == nil {
|
||||||
@@ -19,7 +19,7 @@ func (j JSONB) Value() (driver.Value, error) {
|
|||||||
return json.Marshal(j)
|
return json.Marshal(j)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *JSONB) Scan(value interface{}) error {
|
func (j *JSONB) Scan(value any) error {
|
||||||
if value == nil {
|
if value == nil {
|
||||||
*j = nil
|
*j = nil
|
||||||
return nil
|
return nil
|
||||||
@@ -40,8 +40,8 @@ type Account struct {
|
|||||||
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
||||||
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
||||||
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
||||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||||
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
||||||
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
||||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||||
@@ -68,7 +68,8 @@ type Account struct {
|
|||||||
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
|
||||||
|
|
||||||
// 虚拟字段 (不存储到数据库)
|
// 虚拟字段 (不存储到数据库)
|
||||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||||
|
Groups []*Group `gorm:"-" json:"groups,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (Account) TableName() string {
|
func (Account) TableName() string {
|
||||||
@@ -145,7 +146,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 处理map[string]interface{}类型
|
// 处理map[string]interface{}类型
|
||||||
if m, ok := raw.(map[string]interface{}); ok {
|
if m, ok := raw.(map[string]any); ok {
|
||||||
result := make(map[string]string)
|
result := make(map[string]string)
|
||||||
for k, v := range m {
|
for k, v := range m {
|
||||||
if s, ok := v.(string); ok {
|
if s, ok := v.(string); ok {
|
||||||
@@ -163,7 +164,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
|||||||
// 如果没有设置模型映射,则支持所有模型
|
// 如果没有设置模型映射,则支持所有模型
|
||||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if mapping == nil || len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return true // 没有映射配置,支持所有模型
|
return true // 没有映射配置,支持所有模型
|
||||||
}
|
}
|
||||||
_, exists := mapping[requestedModel]
|
_, exists := mapping[requestedModel]
|
||||||
@@ -174,7 +175,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
|||||||
// 如果没有映射,返回原始模型名
|
// 如果没有映射,返回原始模型名
|
||||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if mapping == nil || len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return requestedModel
|
return requestedModel
|
||||||
}
|
}
|
||||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||||
@@ -231,7 +232,7 @@ func (a *Account) GetCustomErrorCodes() []int {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
||||||
if arr, ok := raw.([]interface{}); ok {
|
if arr, ok := raw.([]any); ok {
|
||||||
result := make([]int, 0, len(arr))
|
result := make([]int, 0, len(arr))
|
||||||
for _, v := range arr {
|
for _, v := range arr {
|
||||||
// JSON 数字默认解析为 float64
|
// JSON 数字默认解析为 float64
|
||||||
@@ -263,3 +264,152 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
|
||||||
|
// 启用后,标题生成、Warmup等预热请求将返回mock响应,不消耗上游token
|
||||||
|
func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||||
|
if a.Credentials == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
|
||||||
|
if enabled, ok := v.(bool); ok {
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============== OpenAI 相关方法 ===============
|
||||||
|
|
||||||
|
// IsOpenAI 检查是否为 OpenAI 平台账号
|
||||||
|
func (a *Account) IsOpenAI() bool {
|
||||||
|
return a.Platform == PlatformOpenAI
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAnthropic 检查是否为 Anthropic 平台账号
|
||||||
|
func (a *Account) IsAnthropic() bool {
|
||||||
|
return a.Platform == PlatformAnthropic
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
|
||||||
|
func (a *Account) IsOpenAIOAuth() bool {
|
||||||
|
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号(Response 账号)
|
||||||
|
func (a *Account) IsOpenAIApiKey() bool {
|
||||||
|
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
|
||||||
|
// 对于 API Key 类型账号,从 credentials 中获取 base_url
|
||||||
|
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
|
||||||
|
func (a *Account) GetOpenAIBaseURL() string {
|
||||||
|
if !a.IsOpenAI() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if a.Type == AccountTypeApiKey {
|
||||||
|
baseURL := a.GetCredential("base_url")
|
||||||
|
if baseURL != "" {
|
||||||
|
return baseURL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "https://api.openai.com" // OpenAI 默认 API URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
|
||||||
|
func (a *Account) GetOpenAIAccessToken() string {
|
||||||
|
if !a.IsOpenAI() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.GetCredential("access_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
|
||||||
|
func (a *Account) GetOpenAIRefreshToken() string {
|
||||||
|
if !a.IsOpenAIOAuth() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.GetCredential("refresh_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIIDToken 获取 OpenAI ID Token(JWT,包含用户信息)
|
||||||
|
func (a *Account) GetOpenAIIDToken() string {
|
||||||
|
if !a.IsOpenAIOAuth() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.GetCredential("id_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIApiKey 获取 OpenAI API Key(用于 Response 账号)
|
||||||
|
func (a *Account) GetOpenAIApiKey() string {
|
||||||
|
if !a.IsOpenAIApiKey() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.GetCredential("api_key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
|
||||||
|
// 返回空字符串表示透传原始 User-Agent
|
||||||
|
func (a *Account) GetOpenAIUserAgent() string {
|
||||||
|
if !a.IsOpenAI() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.GetCredential("user_agent")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChatGPTAccountID 获取 ChatGPT 账号 ID(从 ID Token 解析)
|
||||||
|
func (a *Account) GetChatGPTAccountID() string {
|
||||||
|
if !a.IsOpenAIOAuth() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.GetCredential("chatgpt_account_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChatGPTUserID 获取 ChatGPT 用户 ID(从 ID Token 解析)
|
||||||
|
func (a *Account) GetChatGPTUserID() string {
|
||||||
|
if !a.IsOpenAIOAuth() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.GetCredential("chatgpt_user_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
|
||||||
|
func (a *Account) GetOpenAIOrganizationID() string {
|
||||||
|
if !a.IsOpenAIOAuth() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.GetCredential("organization_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
|
||||||
|
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
|
||||||
|
if !a.IsOpenAIOAuth() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
expiresAtStr := a.GetCredential("expires_at")
|
||||||
|
if expiresAtStr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// 尝试解析时间
|
||||||
|
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||||
|
if err != nil {
|
||||||
|
// 尝试解析为 Unix 时间戳
|
||||||
|
if v, ok := a.Credentials["expires_at"].(float64); ok {
|
||||||
|
t = time.Unix(int64(v), 0)
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
|
||||||
|
func (a *Account) IsOpenAITokenExpired() bool {
|
||||||
|
expiresAt := a.GetOpenAITokenExpiresAt()
|
||||||
|
if expiresAt == nil {
|
||||||
|
return false // 没有过期时间信息,假设未过期
|
||||||
|
}
|
||||||
|
// 提前 60 秒认为过期,便于刷新
|
||||||
|
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Group struct {
|
type Group struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||||
Description string `gorm:"type:text" json:"description"`
|
Description string `gorm:"type:text" json:"description"`
|
||||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||||
|
|
||||||
// 订阅功能字段
|
// 订阅功能字段
|
||||||
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
||||||
|
|||||||
@@ -9,15 +9,16 @@ import (
|
|||||||
type RedeemCode struct {
|
type RedeemCode struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
||||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||||
UsedAt *time.Time `json:"used_at"`
|
UsedAt *time.Time `json:"used_at"`
|
||||||
|
Notes string `gorm:"type:text" json:"notes"`
|
||||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||||
|
|
||||||
// 订阅类型专用字段
|
// 订阅类型专用字段
|
||||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||||
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
||||||
|
|
||||||
// 关联
|
// 关联
|
||||||
@@ -40,8 +41,10 @@ func (r *RedeemCode) CanUse() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRedeemCode 生成唯一的兑换码
|
// GenerateRedeemCode 生成唯一的兑换码
|
||||||
func GenerateRedeemCode() string {
|
func GenerateRedeemCode() (string, error) {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
rand.Read(b)
|
if _, err := rand.Read(b); err != nil {
|
||||||
return hex.EncodeToString(b)
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,17 +19,17 @@ func (Setting) TableName() string {
|
|||||||
// 设置Key常量
|
// 设置Key常量
|
||||||
const (
|
const (
|
||||||
// 注册设置
|
// 注册设置
|
||||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||||
|
|
||||||
// Cloudflare Turnstile 设置
|
// Cloudflare Turnstile 设置
|
||||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||||
@@ -42,12 +42,19 @@ const (
|
|||||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||||
|
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||||
|
|
||||||
// 默认配置
|
// 默认配置
|
||||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||||
|
|
||||||
|
// 管理员 API Key
|
||||||
|
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
|
||||||
|
const AdminApiKeyPrefix = "admin-"
|
||||||
|
|
||||||
// SystemSettings 系统设置结构体(用于API响应)
|
// SystemSettings 系统设置结构体(用于API响应)
|
||||||
type SystemSettings struct {
|
type SystemSettings struct {
|
||||||
// 注册设置
|
// 注册设置
|
||||||
@@ -74,6 +81,7 @@ type SystemSettings struct {
|
|||||||
SiteSubtitle string `json:"site_subtitle"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
ApiBaseUrl string `json:"api_base_url"`
|
ApiBaseUrl string `json:"api_base_url"`
|
||||||
ContactInfo string `json:"contact_info"`
|
ContactInfo string `json:"contact_info"`
|
||||||
|
DocUrl string `json:"doc_url"`
|
||||||
|
|
||||||
// 默认配置
|
// 默认配置
|
||||||
DefaultConcurrency int `json:"default_concurrency"`
|
DefaultConcurrency int `json:"default_concurrency"`
|
||||||
@@ -91,5 +99,6 @@ type PublicSettings struct {
|
|||||||
SiteSubtitle string `json:"site_subtitle"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
ApiBaseUrl string `json:"api_base_url"`
|
ApiBaseUrl string `json:"api_base_url"`
|
||||||
ContactInfo string `json:"contact_info"`
|
ContactInfo string `json:"contact_info"`
|
||||||
|
DocUrl string `json:"doc_url"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ type UsageLog struct {
|
|||||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
||||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
||||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
||||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
||||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||||
|
Username string `gorm:"size:100;default:''" json:"username"`
|
||||||
|
Wechat string `gorm:"size:100;default:''" json:"wechat"`
|
||||||
|
Notes string `gorm:"type:text;default:''" json:"notes"`
|
||||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||||
@@ -22,7 +25,8 @@ type User struct {
|
|||||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||||
|
|
||||||
// 关联
|
// 关联
|
||||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||||
|
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (User) TableName() string {
|
func (User) TableName() string {
|
||||||
|
|||||||
74
backend/internal/pkg/claude/constants.go
Normal file
74
backend/internal/pkg/claude/constants.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
// Claude Code 客户端相关常量
|
||||||
|
|
||||||
|
// Beta header 常量
|
||||||
|
const (
|
||||||
|
BetaOAuth = "oauth-2025-04-20"
|
||||||
|
BetaClaudeCode = "claude-code-20250219"
|
||||||
|
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||||
|
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||||
|
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||||
|
|
||||||
|
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||||
|
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
|
// Claude Code 客户端默认请求头
|
||||||
|
var DefaultHeaders = map[string]string{
|
||||||
|
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||||
|
"X-Stainless-Lang": "js",
|
||||||
|
"X-Stainless-Package-Version": "0.52.0",
|
||||||
|
"X-Stainless-OS": "Linux",
|
||||||
|
"X-Stainless-Arch": "x64",
|
||||||
|
"X-Stainless-Runtime": "node",
|
||||||
|
"X-Stainless-Runtime-Version": "v22.14.0",
|
||||||
|
"X-Stainless-Retry-Count": "0",
|
||||||
|
"X-Stainless-Timeout": "60",
|
||||||
|
"X-App": "cli",
|
||||||
|
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model 表示一个 Claude 模型
|
||||||
|
type Model struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModels Claude Code 客户端支持的默认模型列表
|
||||||
|
var DefaultModels = []Model{
|
||||||
|
{
|
||||||
|
ID: "claude-opus-4-5-20251101",
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: "Claude Opus 4.5",
|
||||||
|
CreatedAt: "2025-11-01T00:00:00Z",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4-5-20250929",
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: "Claude Sonnet 4.5",
|
||||||
|
CreatedAt: "2025-09-29T00:00:00Z",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-haiku-4-5-20251001",
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: "Claude Haiku 4.5",
|
||||||
|
CreatedAt: "2025-10-01T00:00:00Z",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModelIDs 返回默认模型的 ID 列表
|
||||||
|
func DefaultModelIDs() []string {
|
||||||
|
ids := make([]string, len(DefaultModels))
|
||||||
|
for i, m := range DefaultModels {
|
||||||
|
ids[i] = m.ID
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultTestModel 测试时使用的默认模型
|
||||||
|
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||||
@@ -43,18 +43,25 @@ type OAuthSession struct {
|
|||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
sessions map[string]*OAuthSession
|
sessions map[string]*OAuthSession
|
||||||
|
stopCh chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSessionStore creates a new session store
|
// NewSessionStore creates a new session store
|
||||||
func NewSessionStore() *SessionStore {
|
func NewSessionStore() *SessionStore {
|
||||||
store := &SessionStore{
|
store := &SessionStore{
|
||||||
sessions: make(map[string]*OAuthSession),
|
sessions: make(map[string]*OAuthSession),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
// Start cleanup goroutine
|
// Start cleanup goroutine
|
||||||
go store.cleanup()
|
go store.cleanup()
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop stops the cleanup goroutine
|
||||||
|
func (s *SessionStore) Stop() {
|
||||||
|
close(s.stopCh)
|
||||||
|
}
|
||||||
|
|
||||||
// Set stores a session
|
// Set stores a session
|
||||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@@ -87,14 +94,20 @@ func (s *SessionStore) Delete(sessionID string) {
|
|||||||
// cleanup removes expired sessions periodically
|
// cleanup removes expired sessions periodically
|
||||||
func (s *SessionStore) cleanup() {
|
func (s *SessionStore) cleanup() {
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
for range ticker.C {
|
defer ticker.Stop()
|
||||||
s.mu.Lock()
|
for {
|
||||||
for id, session := range s.sessions {
|
select {
|
||||||
if time.Since(session.CreatedAt) > SessionTTL {
|
case <-s.stopCh:
|
||||||
delete(s.sessions, id)
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.mu.Lock()
|
||||||
|
for id, session := range s.sessions {
|
||||||
|
if time.Since(session.CreatedAt) > SessionTTL {
|
||||||
|
delete(s.sessions, id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
42
backend/internal/pkg/openai/constants.go
Normal file
42
backend/internal/pkg/openai/constants.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import _ "embed"
|
||||||
|
|
||||||
|
// Model represents an OpenAI model
|
||||||
|
type Model struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModels OpenAI models list
|
||||||
|
var DefaultModels = []Model{
|
||||||
|
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||||
|
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||||
|
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||||
|
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
|
||||||
|
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
|
||||||
|
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
|
||||||
|
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModelIDs returns the default model ID list
|
||||||
|
func DefaultModelIDs() []string {
|
||||||
|
ids := make([]string, len(DefaultModels))
|
||||||
|
for i, m := range DefaultModels {
|
||||||
|
ids[i] = m.ID
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultTestModel default model for testing OpenAI accounts
|
||||||
|
const DefaultTestModel = "gpt-5.1-codex"
|
||||||
|
|
||||||
|
// DefaultInstructions default instructions for non-Codex CLI requests
|
||||||
|
// Content loaded from instructions.txt at compile time
|
||||||
|
//
|
||||||
|
//go:embed instructions.txt
|
||||||
|
var DefaultInstructions string
|
||||||
118
backend/internal/pkg/openai/instructions.txt
Normal file
118
backend/internal/pkg/openai/instructions.txt
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||||
|
|
||||||
|
## General
|
||||||
|
|
||||||
|
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||||
|
|
||||||
|
## Editing constraints
|
||||||
|
|
||||||
|
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||||
|
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||||
|
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||||
|
- You may be in a dirty git worktree.
|
||||||
|
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||||
|
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||||
|
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||||
|
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||||
|
- Do not amend a commit unless explicitly requested to do so.
|
||||||
|
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||||
|
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||||
|
|
||||||
|
## Plan tool
|
||||||
|
|
||||||
|
When using the planning tool:
|
||||||
|
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||||
|
- Do not make single-step plans.
|
||||||
|
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||||
|
|
||||||
|
## Codex CLI harness, sandboxing, and approvals
|
||||||
|
|
||||||
|
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||||
|
|
||||||
|
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||||
|
- **read-only**: The sandbox only permits reading files.
|
||||||
|
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||||
|
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||||
|
|
||||||
|
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||||
|
- **restricted**: Requires approval
|
||||||
|
- **enabled**: No approval needed
|
||||||
|
|
||||||
|
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||||
|
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
|
||||||
|
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||||
|
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||||
|
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||||
|
|
||||||
|
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||||
|
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||||
|
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||||
|
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||||
|
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||||
|
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||||
|
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||||
|
|
||||||
|
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||||
|
|
||||||
|
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||||
|
|
||||||
|
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
|
||||||
|
|
||||||
|
When requesting approval to execute a command that will require escalated privileges:
|
||||||
|
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
|
||||||
|
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||||
|
|
||||||
|
## Special user requests
|
||||||
|
|
||||||
|
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||||
|
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||||
|
|
||||||
|
## Frontend tasks
|
||||||
|
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
|
||||||
|
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||||
|
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||||
|
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||||
|
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||||
|
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||||
|
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||||
|
- Ensure the page loads properly on both desktop and mobile
|
||||||
|
|
||||||
|
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||||
|
|
||||||
|
## Presenting your work and final message
|
||||||
|
|
||||||
|
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||||
|
|
||||||
|
- Default: be very concise; friendly coding teammate tone.
|
||||||
|
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||||
|
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||||
|
- Skip heavy formatting for simple confirmations.
|
||||||
|
- Don't dump large files you've written; reference paths only.
|
||||||
|
- No \"save/copy this file\" - User is on the same machine.
|
||||||
|
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||||
|
- For code changes:
|
||||||
|
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
|
||||||
|
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||||
|
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||||
|
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||||
|
|
||||||
|
### Final answer structure and style guidelines
|
||||||
|
|
||||||
|
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||||
|
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||||
|
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||||
|
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||||
|
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||||
|
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||||
|
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no \"above/below\"; parallel wording.
|
||||||
|
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||||
|
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||||
|
- File References: When referencing files in your response follow the below rules:
|
||||||
|
* Use inline code to make file paths clickable.
|
||||||
|
* Each reference should have a stand alone path. Even if it's the same file.
|
||||||
|
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||||
|
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||||
|
* Do not use URIs like file://, vscode://, or https://.
|
||||||
|
* Do not provide range of lines
|
||||||
|
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5
|
||||||
|
|
||||||
366
backend/internal/pkg/openai/oauth.go
Normal file
366
backend/internal/pkg/openai/oauth.go
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
|
||||||
|
const (
|
||||||
|
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||||
|
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
|
|
||||||
|
// OAuth endpoints
|
||||||
|
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||||
|
TokenURL = "https://auth.openai.com/oauth/token"
|
||||||
|
|
||||||
|
// Default redirect URI (can be customized)
|
||||||
|
DefaultRedirectURI = "http://localhost:1455/auth/callback"
|
||||||
|
|
||||||
|
// Scopes
|
||||||
|
DefaultScopes = "openid profile email offline_access"
|
||||||
|
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
|
||||||
|
RefreshScopes = "openid profile email"
|
||||||
|
|
||||||
|
// Session TTL
|
||||||
|
SessionTTL = 30 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuthSession stores OAuth flow state for OpenAI
|
||||||
|
type OAuthSession struct {
|
||||||
|
State string `json:"state"`
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
ProxyURL string `json:"proxy_url,omitempty"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionStore manages OAuth sessions in memory
|
||||||
|
type SessionStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
sessions map[string]*OAuthSession
|
||||||
|
stopCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionStore creates a new session store
|
||||||
|
func NewSessionStore() *SessionStore {
|
||||||
|
store := &SessionStore{
|
||||||
|
sessions: make(map[string]*OAuthSession),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
// Start cleanup goroutine
|
||||||
|
go store.cleanup()
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stores a session
|
||||||
|
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.sessions[sessionID] = session
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a session
|
||||||
|
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
session, ok := s.sessions[sessionID]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// Check if expired
|
||||||
|
if time.Since(session.CreatedAt) > SessionTTL {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return session, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a session
|
||||||
|
func (s *SessionStore) Delete(sessionID string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.sessions, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the cleanup goroutine
|
||||||
|
func (s *SessionStore) Stop() {
|
||||||
|
close(s.stopCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup removes expired sessions periodically
|
||||||
|
func (s *SessionStore) cleanup() {
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.mu.Lock()
|
||||||
|
for id, session := range s.sessions {
|
||||||
|
if time.Since(session.CreatedAt) > SessionTTL {
|
||||||
|
delete(s.sessions, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||||
|
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||||
|
b := make([]byte, n)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateState generates a random state string for OAuth
|
||||||
|
func GenerateState() (string, error) {
|
||||||
|
bytes, err := GenerateRandomBytes(32)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateSessionID generates a unique session ID
|
||||||
|
func GenerateSessionID() (string, error) {
|
||||||
|
bytes, err := GenerateRandomBytes(16)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
|
||||||
|
// OpenAI uses hex encoding instead of base64url
|
||||||
|
func GenerateCodeVerifier() (string, error) {
|
||||||
|
bytes, err := GenerateRandomBytes(64)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||||
|
// Uses base64url encoding as per RFC 7636
|
||||||
|
func GenerateCodeChallenge(verifier string) string {
|
||||||
|
hash := sha256.Sum256([]byte(verifier))
|
||||||
|
return base64URLEncode(hash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// base64URLEncode encodes bytes to base64url without padding
|
||||||
|
func base64URLEncode(data []byte) string {
|
||||||
|
encoded := base64.URLEncoding.EncodeToString(data)
|
||||||
|
// Remove padding
|
||||||
|
return strings.TrimRight(encoded, "=")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||||
|
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||||
|
if redirectURI == "" {
|
||||||
|
redirectURI = DefaultRedirectURI
|
||||||
|
}
|
||||||
|
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Set("client_id", ClientID)
|
||||||
|
params.Set("redirect_uri", redirectURI)
|
||||||
|
params.Set("scope", DefaultScopes)
|
||||||
|
params.Set("state", state)
|
||||||
|
params.Set("code_challenge", codeChallenge)
|
||||||
|
params.Set("code_challenge_method", "S256")
|
||||||
|
// OpenAI specific parameters
|
||||||
|
params.Set("id_token_add_organizations", "true")
|
||||||
|
params.Set("codex_cli_simplified_flow", "true")
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRequest represents the token exchange request body
|
||||||
|
type TokenRequest struct {
|
||||||
|
GrantType string `json:"grant_type"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
Code string `json:"code"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenResponse represents the token response from OpenAI OAuth
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
Scope string `json:"scope,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenRequest represents the refresh token request
|
||||||
|
type RefreshTokenRequest struct {
|
||||||
|
GrantType string `json:"grant_type"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDTokenClaims represents the claims from OpenAI ID Token
|
||||||
|
type IDTokenClaims struct {
|
||||||
|
// Standard claims
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
Iss string `json:"iss"`
|
||||||
|
Aud []string `json:"aud"` // OpenAI returns aud as an array
|
||||||
|
Exp int64 `json:"exp"`
|
||||||
|
Iat int64 `json:"iat"`
|
||||||
|
|
||||||
|
// OpenAI specific claims (nested under https://api.openai.com/auth)
|
||||||
|
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAIAuthClaims represents the OpenAI specific auth claims
|
||||||
|
type OpenAIAuthClaims struct {
|
||||||
|
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||||
|
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Organizations []OrganizationClaim `json:"organizations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrganizationClaim represents an organization in the ID Token
|
||||||
|
type OrganizationClaim struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
IsDefault bool `json:"is_default"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildTokenRequest creates a token exchange request for OpenAI
|
||||||
|
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
|
||||||
|
if redirectURI == "" {
|
||||||
|
redirectURI = DefaultRedirectURI
|
||||||
|
}
|
||||||
|
return &TokenRequest{
|
||||||
|
GrantType: "authorization_code",
|
||||||
|
ClientID: ClientID,
|
||||||
|
Code: code,
|
||||||
|
RedirectURI: redirectURI,
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
|
||||||
|
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||||
|
return &RefreshTokenRequest{
|
||||||
|
GrantType: "refresh_token",
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
ClientID: ClientID,
|
||||||
|
Scope: RefreshScopes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToFormData converts TokenRequest to URL-encoded form data
|
||||||
|
func (r *TokenRequest) ToFormData() string {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("grant_type", r.GrantType)
|
||||||
|
params.Set("client_id", r.ClientID)
|
||||||
|
params.Set("code", r.Code)
|
||||||
|
params.Set("redirect_uri", r.RedirectURI)
|
||||||
|
params.Set("code_verifier", r.CodeVerifier)
|
||||||
|
return params.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToFormData converts RefreshTokenRequest to URL-encoded form data
|
||||||
|
func (r *RefreshTokenRequest) ToFormData() string {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("grant_type", r.GrantType)
|
||||||
|
params.Set("client_id", r.ClientID)
|
||||||
|
params.Set("refresh_token", r.RefreshToken)
|
||||||
|
params.Set("scope", r.Scope)
|
||||||
|
return params.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||||
|
// Note: This does NOT verify the signature - it only decodes the payload
|
||||||
|
// For production, you should verify the token signature using OpenAI's public keys
|
||||||
|
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||||
|
parts := strings.Split(idToken, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode payload (second part)
|
||||||
|
payload := parts[1]
|
||||||
|
// Add padding if necessary
|
||||||
|
switch len(payload) % 4 {
|
||||||
|
case 2:
|
||||||
|
payload += "=="
|
||||||
|
case 3:
|
||||||
|
payload += "="
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
// Try standard encoding
|
||||||
|
decoded, err = base64.StdEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims IDTokenClaims
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractUserInfo extracts user information from ID Token claims
|
||||||
|
type UserInfo struct {
|
||||||
|
Email string
|
||||||
|
ChatGPTAccountID string
|
||||||
|
ChatGPTUserID string
|
||||||
|
UserID string
|
||||||
|
OrganizationID string
|
||||||
|
Organizations []OrganizationClaim
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserInfo extracts user info from ID Token claims
|
||||||
|
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||||
|
info := &UserInfo{
|
||||||
|
Email: c.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.OpenAIAuth != nil {
|
||||||
|
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||||
|
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||||
|
info.UserID = c.OpenAIAuth.UserID
|
||||||
|
info.Organizations = c.OpenAIAuth.Organizations
|
||||||
|
|
||||||
|
// Get default organization ID
|
||||||
|
for _, org := range c.OpenAIAuth.Organizations {
|
||||||
|
if org.IsDefault {
|
||||||
|
info.OrganizationID = org.ID
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If no default, use first org
|
||||||
|
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
|
||||||
|
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
18
backend/internal/pkg/openai/request.go
Normal file
18
backend/internal/pkg/openai/request.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||||
|
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||||
|
var CodexCLIUserAgentPrefixes = []string{
|
||||||
|
"codex_vscode/",
|
||||||
|
"codex_cli_rs/",
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||||
|
func IsCodexCLIRequest(userAgent string) bool {
|
||||||
|
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||||
|
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
42
backend/internal/pkg/pagination/pagination.go
Normal file
42
backend/internal/pkg/pagination/pagination.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package pagination
|
||||||
|
|
||||||
|
// PaginationParams 分页参数
|
||||||
|
type PaginationParams struct {
|
||||||
|
Page int
|
||||||
|
PageSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// PaginationResult 分页结果
|
||||||
|
type PaginationResult struct {
|
||||||
|
Total int64
|
||||||
|
Page int
|
||||||
|
PageSize int
|
||||||
|
Pages int
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultPagination 默认分页参数
|
||||||
|
func DefaultPagination() PaginationParams {
|
||||||
|
return PaginationParams{
|
||||||
|
Page: 1,
|
||||||
|
PageSize: 20,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offset 计算偏移量
|
||||||
|
func (p PaginationParams) Offset() int {
|
||||||
|
if p.Page < 1 {
|
||||||
|
p.Page = 1
|
||||||
|
}
|
||||||
|
return (p.Page - 1) * p.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit 获取限制数
|
||||||
|
func (p PaginationParams) Limit() int {
|
||||||
|
if p.PageSize < 1 {
|
||||||
|
return 20
|
||||||
|
}
|
||||||
|
if p.PageSize > 100 {
|
||||||
|
return 100
|
||||||
|
}
|
||||||
|
return p.PageSize
|
||||||
|
}
|
||||||
@@ -9,22 +9,22 @@ import (
|
|||||||
|
|
||||||
// Response 标准API响应格式
|
// Response 标准API响应格式
|
||||||
type Response struct {
|
type Response struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Data interface{} `json:"data,omitempty"`
|
Data any `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PaginatedData 分页数据格式(匹配前端期望)
|
// PaginatedData 分页数据格式(匹配前端期望)
|
||||||
type PaginatedData struct {
|
type PaginatedData struct {
|
||||||
Items interface{} `json:"items"`
|
Items any `json:"items"`
|
||||||
Total int64 `json:"total"`
|
Total int64 `json:"total"`
|
||||||
Page int `json:"page"`
|
Page int `json:"page"`
|
||||||
PageSize int `json:"page_size"`
|
PageSize int `json:"page_size"`
|
||||||
Pages int `json:"pages"`
|
Pages int `json:"pages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Success 返回成功响应
|
// Success 返回成功响应
|
||||||
func Success(c *gin.Context, data interface{}) {
|
func Success(c *gin.Context, data any) {
|
||||||
c.JSON(http.StatusOK, Response{
|
c.JSON(http.StatusOK, Response{
|
||||||
Code: 0,
|
Code: 0,
|
||||||
Message: "success",
|
Message: "success",
|
||||||
@@ -33,7 +33,7 @@ func Success(c *gin.Context, data interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Created 返回创建成功响应
|
// Created 返回创建成功响应
|
||||||
func Created(c *gin.Context, data interface{}) {
|
func Created(c *gin.Context, data any) {
|
||||||
c.JSON(http.StatusCreated, Response{
|
c.JSON(http.StatusCreated, Response{
|
||||||
Code: 0,
|
Code: 0,
|
||||||
Message: "success",
|
Message: "success",
|
||||||
@@ -75,7 +75,7 @@ func InternalError(c *gin.Context, message string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Paginated 返回分页数据
|
// Paginated 返回分页数据
|
||||||
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
|
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
||||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||||
if pages < 1 {
|
if pages < 1 {
|
||||||
pages = 1
|
pages = 1
|
||||||
@@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// PaginationResult 分页结果(与repository.PaginationResult兼容)
|
// PaginationResult 分页结果(与pagination.PaginationResult兼容)
|
||||||
type PaginationResult struct {
|
type PaginationResult struct {
|
||||||
Total int64
|
Total int64
|
||||||
Page int
|
Page int
|
||||||
@@ -99,7 +99,7 @@ type PaginationResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||||
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
|
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
||||||
if pagination == nil {
|
if pagination == nil {
|
||||||
Success(c, PaginatedData{
|
Success(c, PaginatedData{
|
||||||
Items: items,
|
Items: items,
|
||||||
|
|||||||
@@ -1,43 +1,39 @@
|
|||||||
package sysutil
|
package sysutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"os/exec"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const serviceName = "sub2api"
|
// RestartService triggers a service restart by gracefully exiting.
|
||||||
|
|
||||||
// RestartService triggers a service restart via systemd.
|
|
||||||
//
|
//
|
||||||
// IMPORTANT: This function initiates the restart and returns immediately.
|
// This relies on systemd's Restart=always configuration to automatically
|
||||||
// The actual restart happens asynchronously - the current process will be killed
|
// restart the service after it exits. This is the industry-standard approach:
|
||||||
// by systemd and a new process will be started.
|
// - Simple and reliable
|
||||||
//
|
// - No sudo permissions needed
|
||||||
// We use Start() instead of Run() because:
|
// - No complex process management
|
||||||
// - systemctl restart will kill the current process first
|
// - Leverages systemd's native restart capability
|
||||||
// - Run() waits for completion, but the process dies before completion
|
|
||||||
// - Start() spawns the command independently, allowing systemd to handle the full cycle
|
|
||||||
//
|
//
|
||||||
// Prerequisites:
|
// Prerequisites:
|
||||||
// - Linux OS with systemd
|
// - Linux OS with systemd
|
||||||
// - NOPASSWD sudo access configured (install.sh creates /etc/sudoers.d/sub2api)
|
// - Service configured with Restart=always in systemd unit file
|
||||||
func RestartService() error {
|
func RestartService() error {
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
return fmt.Errorf("systemd restart only available on Linux")
|
log.Println("Service restart via exit only works on Linux with systemd")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Initiating service restart...")
|
log.Println("Initiating service restart by graceful exit...")
|
||||||
|
log.Println("systemd will automatically restart the service (Restart=always)")
|
||||||
|
|
||||||
// The sub2api user has NOPASSWD sudo access for systemctl commands
|
// Give a moment for logs to flush and response to be sent
|
||||||
// (configured by install.sh in /etc/sudoers.d/sub2api).
|
go func() {
|
||||||
cmd := exec.Command("sudo", "systemctl", "restart", serviceName)
|
time.Sleep(100 * time.Millisecond)
|
||||||
if err := cmd.Start(); err != nil {
|
os.Exit(0)
|
||||||
return fmt.Errorf("failed to initiate service restart: %w", err)
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
log.Println("Service restart initiated successfully")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
|
|||||||
|
|
||||||
func TestTimeNowAffected(t *testing.T) {
|
func TestTimeNowAffected(t *testing.T) {
|
||||||
// Reset to UTC first
|
// Reset to UTC first
|
||||||
Init("UTC")
|
if err := Init("UTC"); err != nil {
|
||||||
|
t.Fatalf("Init failed with UTC: %v", err)
|
||||||
|
}
|
||||||
utcNow := time.Now()
|
utcNow := time.Now()
|
||||||
|
|
||||||
// Switch to Shanghai (UTC+8)
|
// Switch to Shanghai (UTC+8)
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
shanghaiNow := time.Now()
|
shanghaiNow := time.Now()
|
||||||
|
|
||||||
// The times should be the same instant, but different timezone representation
|
// The times should be the same instant, but different timezone representation
|
||||||
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestToday(t *testing.T) {
|
func TestToday(t *testing.T) {
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
today := Today()
|
today := Today()
|
||||||
now := Now()
|
now := Now()
|
||||||
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStartOfDay(t *testing.T) {
|
func TestStartOfDay(t *testing.T) {
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create a time at 15:30:45
|
// Create a time at 15:30:45
|
||||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||||
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
|
|||||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||||
// and why StartOfDay is more reliable for timezone-aware code
|
// and why StartOfDay is more reliable for timezone-aware code
|
||||||
|
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
now := Now()
|
now := Now()
|
||||||
|
|
||||||
|
|||||||
8
backend/internal/pkg/usagestats/account_stats.go
Normal file
8
backend/internal/pkg/usagestats/account_stats.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package usagestats
|
||||||
|
|
||||||
|
// AccountStats 账号使用统计
|
||||||
|
type AccountStats struct {
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
}
|
||||||
209
backend/internal/pkg/usagestats/usage_log_types.go
Normal file
209
backend/internal/pkg/usagestats/usage_log_types.go
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
package usagestats
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// DashboardStats 仪表盘统计
|
||||||
|
type DashboardStats struct {
|
||||||
|
// 用户统计
|
||||||
|
TotalUsers int64 `json:"total_users"`
|
||||||
|
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
||||||
|
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||||
|
|
||||||
|
// API Key 统计
|
||||||
|
TotalApiKeys int64 `json:"total_api_keys"`
|
||||||
|
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||||
|
|
||||||
|
// 账户统计
|
||||||
|
TotalAccounts int64 `json:"total_accounts"`
|
||||||
|
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
|
||||||
|
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
|
||||||
|
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
|
||||||
|
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
|
||||||
|
|
||||||
|
// 累计 Token 使用统计
|
||||||
|
TotalRequests int64 `json:"total_requests"`
|
||||||
|
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||||
|
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||||
|
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||||
|
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||||
|
|
||||||
|
// 今日 Token 使用统计
|
||||||
|
TodayRequests int64 `json:"today_requests"`
|
||||||
|
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||||
|
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||||
|
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||||
|
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||||
|
TodayTokens int64 `json:"today_tokens"`
|
||||||
|
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||||
|
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||||
|
|
||||||
|
// 系统运行统计
|
||||||
|
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
||||||
|
|
||||||
|
// 性能指标
|
||||||
|
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||||
|
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrendDataPoint represents a single point in trend data
|
||||||
|
type TrendDataPoint struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
InputTokens int64 `json:"input_tokens"`
|
||||||
|
OutputTokens int64 `json:"output_tokens"`
|
||||||
|
CacheTokens int64 `json:"cache_tokens"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelStat represents usage statistics for a single model
|
||||||
|
type ModelStat struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
InputTokens int64 `json:"input_tokens"`
|
||||||
|
OutputTokens int64 `json:"output_tokens"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserUsageTrendPoint represents user usage trend data point
|
||||||
|
type UserUsageTrendPoint struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||||
|
type ApiKeyUsageTrendPoint struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
ApiKeyID int64 `json:"api_key_id"`
|
||||||
|
KeyName string `json:"key_name"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserDashboardStats 用户仪表盘统计
|
||||||
|
type UserDashboardStats struct {
|
||||||
|
// API Key 统计
|
||||||
|
TotalApiKeys int64 `json:"total_api_keys"`
|
||||||
|
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||||
|
|
||||||
|
// 累计 Token 使用统计
|
||||||
|
TotalRequests int64 `json:"total_requests"`
|
||||||
|
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||||
|
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||||
|
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||||
|
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||||
|
|
||||||
|
// 今日 Token 使用统计
|
||||||
|
TodayRequests int64 `json:"today_requests"`
|
||||||
|
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||||
|
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||||
|
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||||
|
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||||
|
TodayTokens int64 `json:"today_tokens"`
|
||||||
|
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||||
|
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||||
|
|
||||||
|
// 性能统计
|
||||||
|
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||||
|
|
||||||
|
// 性能指标
|
||||||
|
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||||
|
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsageLogFilters represents filters for usage log queries
|
||||||
|
type UsageLogFilters struct {
|
||||||
|
UserID int64
|
||||||
|
ApiKeyID int64
|
||||||
|
StartTime *time.Time
|
||||||
|
EndTime *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsageStats represents usage statistics
|
||||||
|
type UsageStats struct {
|
||||||
|
TotalRequests int64 `json:"total_requests"`
|
||||||
|
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||||
|
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||||
|
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
TotalCost float64 `json:"total_cost"`
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
|
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
|
type BatchUserUsageStats struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
TodayActualCost float64 `json:"today_actual_cost"`
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||||
|
type BatchApiKeyUsageStats struct {
|
||||||
|
ApiKeyID int64 `json:"api_key_id"`
|
||||||
|
TodayActualCost float64 `json:"today_actual_cost"`
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountUsageHistory represents daily usage history for an account
|
||||||
|
type AccountUsageHistory struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
ActualCost float64 `json:"actual_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountUsageSummary represents summary statistics for an account
|
||||||
|
type AccountUsageSummary struct {
|
||||||
|
Days int `json:"days"`
|
||||||
|
ActualDaysUsed int `json:"actual_days_used"`
|
||||||
|
TotalCost float64 `json:"total_cost"`
|
||||||
|
TotalStandardCost float64 `json:"total_standard_cost"`
|
||||||
|
TotalRequests int64 `json:"total_requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
AvgDailyCost float64 `json:"avg_daily_cost"`
|
||||||
|
AvgDailyRequests float64 `json:"avg_daily_requests"`
|
||||||
|
AvgDailyTokens float64 `json:"avg_daily_tokens"`
|
||||||
|
AvgDurationMs float64 `json:"avg_duration_ms"`
|
||||||
|
Today *struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
} `json:"today"`
|
||||||
|
HighestCostDay *struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
} `json:"highest_cost_day"`
|
||||||
|
HighestRequestDay *struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
} `json:"highest_request_day"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||||
|
type AccountUsageStatsResponse struct {
|
||||||
|
History []AccountUsageHistory `json:"history"`
|
||||||
|
Summary AccountUsageSummary `json:"summary"`
|
||||||
|
Models []ModelStat `json:"models"`
|
||||||
|
}
|
||||||
@@ -2,10 +2,14 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sub2api/internal/model"
|
"errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AccountRepository struct {
|
type AccountRepository struct {
|
||||||
@@ -22,14 +26,34 @@ func (r *AccountRepository) Create(ctx context.Context, account *model.Account)
|
|||||||
|
|
||||||
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||||
var account model.Account
|
var account model.Account
|
||||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
|
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// 填充 GroupIDs 虚拟字段
|
// 填充 GroupIDs 和 Groups 虚拟字段
|
||||||
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
|
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
|
||||||
|
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
|
||||||
for _, ag := range account.AccountGroups {
|
for _, ag := range account.AccountGroups {
|
||||||
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
|
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
|
||||||
|
if ag.Group != nil {
|
||||||
|
account.Groups = append(account.Groups, ag.Group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
|
||||||
|
if crsAccountID == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var account model.Account
|
||||||
|
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
@@ -47,12 +71,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
|
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
|
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
|
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
|
||||||
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
|
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
|
||||||
var accounts []model.Account
|
var accounts []model.Account
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -77,15 +101,19 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 填充每个 Account 的 GroupIDs 虚拟字段
|
// 填充每个 Account 的虚拟字段(GroupIDs 和 Groups)
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
|
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
|
||||||
|
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
|
||||||
for _, ag := range accounts[i].AccountGroups {
|
for _, ag := range accounts[i].AccountGroups {
|
||||||
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
|
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
|
||||||
|
if ag.Group != nil {
|
||||||
|
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +122,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return accounts, &PaginationResult{
|
return accounts, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -130,7 +158,7 @@ func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
|||||||
|
|
||||||
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"status": model.StatusError,
|
"status": model.StatusError,
|
||||||
"error_message": errorMsg,
|
"error_message": errorMsg,
|
||||||
}).Error
|
}).Error
|
||||||
@@ -221,12 +249,44 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
|
|||||||
return accounts, err
|
return accounts, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListSchedulableByPlatform 按平台获取可调度的账号
|
||||||
|
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||||
|
var accounts []model.Account
|
||||||
|
now := time.Now()
|
||||||
|
err := r.db.WithContext(ctx).
|
||||||
|
Where("platform = ?", platform).
|
||||||
|
Where("status = ? AND schedulable = ?", model.StatusActive, true).
|
||||||
|
Where("(overload_until IS NULL OR overload_until <= ?)", now).
|
||||||
|
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
|
||||||
|
Preload("Proxy").
|
||||||
|
Order("priority ASC").
|
||||||
|
Find(&accounts).Error
|
||||||
|
return accounts, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
|
||||||
|
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
|
||||||
|
var accounts []model.Account
|
||||||
|
now := time.Now()
|
||||||
|
err := r.db.WithContext(ctx).
|
||||||
|
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||||
|
Where("account_groups.group_id = ?", groupID).
|
||||||
|
Where("accounts.platform = ?", platform).
|
||||||
|
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
|
||||||
|
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
|
||||||
|
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
|
||||||
|
Preload("Proxy").
|
||||||
|
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||||
|
Find(&accounts).Error
|
||||||
|
return accounts, err
|
||||||
|
}
|
||||||
|
|
||||||
// SetRateLimited 标记账号为限流状态(429)
|
// SetRateLimited 标记账号为限流状态(429)
|
||||||
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"rate_limited_at": now,
|
"rate_limited_at": now,
|
||||||
"rate_limit_reset_at": resetAt,
|
"rate_limit_reset_at": resetAt,
|
||||||
}).Error
|
}).Error
|
||||||
}
|
}
|
||||||
@@ -240,7 +300,7 @@ func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
|||||||
// ClearRateLimit 清除账号的限流状态
|
// ClearRateLimit 清除账号的限流状态
|
||||||
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"rate_limited_at": nil,
|
"rate_limited_at": nil,
|
||||||
"rate_limit_reset_at": nil,
|
"rate_limit_reset_at": nil,
|
||||||
"overload_until": nil,
|
"overload_until": nil,
|
||||||
@@ -249,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
|||||||
|
|
||||||
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
||||||
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
updates := map[string]interface{}{
|
updates := map[string]any{
|
||||||
"session_window_status": status,
|
"session_window_status": status,
|
||||||
}
|
}
|
||||||
if start != nil {
|
if start != nil {
|
||||||
@@ -266,3 +326,75 @@ func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
|||||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||||
Update("schedulable", schedulable).Error
|
Update("schedulable", schedulable).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateExtra updates specific fields in account's Extra JSONB field
|
||||||
|
// It merges the updates into existing Extra data without overwriting other fields
|
||||||
|
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current account to preserve existing Extra data
|
||||||
|
var account model.Account
|
||||||
|
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize Extra if nil
|
||||||
|
if account.Extra == nil {
|
||||||
|
account.Extra = make(model.JSONB)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge updates into existing Extra
|
||||||
|
for k, v := range updates {
|
||||||
|
account.Extra[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save updated Extra
|
||||||
|
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||||
|
Update("extra", account.Extra).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BulkUpdate updates multiple accounts with the provided fields.
|
||||||
|
// It merges credentials/extra JSONB fields instead of overwriting them.
|
||||||
|
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
updateMap := map[string]any{}
|
||||||
|
|
||||||
|
if updates.Name != nil {
|
||||||
|
updateMap["name"] = *updates.Name
|
||||||
|
}
|
||||||
|
if updates.ProxyID != nil {
|
||||||
|
updateMap["proxy_id"] = updates.ProxyID
|
||||||
|
}
|
||||||
|
if updates.Concurrency != nil {
|
||||||
|
updateMap["concurrency"] = *updates.Concurrency
|
||||||
|
}
|
||||||
|
if updates.Priority != nil {
|
||||||
|
updateMap["priority"] = *updates.Priority
|
||||||
|
}
|
||||||
|
if updates.Status != nil {
|
||||||
|
updateMap["status"] = *updates.Status
|
||||||
|
}
|
||||||
|
if len(updates.Credentials) > 0 {
|
||||||
|
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
|
||||||
|
}
|
||||||
|
if len(updates.Extra) > 0 {
|
||||||
|
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(updateMap) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := r.db.WithContext(ctx).
|
||||||
|
Model(&model.Account{}).
|
||||||
|
Where("id IN ?", ids).
|
||||||
|
Clauses(clause.Returning{}).
|
||||||
|
Updates(updateMap)
|
||||||
|
|
||||||
|
return result.RowsAffected, result.Error
|
||||||
|
}
|
||||||
|
|||||||
51
backend/internal/repository/api_key_cache.go
Normal file
51
backend/internal/repository/api_key_cache.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||||
|
apiKeyRateLimitDuration = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiKeyCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
|
||||||
|
return &apiKeyCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
|
return c.rdb.Get(ctx, key).Int()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||||
|
return c.rdb.Incr(ctx, apiKey).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||||
|
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||||
|
}
|
||||||
@@ -2,7 +2,8 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
|
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
|
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||||
var keys []model.ApiKey
|
var keys []model.ApiKey
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return keys, &PaginationResult{
|
return keys, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
|
|||||||
return count > 0, err
|
return count > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
|
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||||
var keys []model.ApiKey
|
var keys []model.ApiKey
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return keys, &PaginationResult{
|
return keys, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
|
|||||||
174
backend/internal/repository/billing_cache.go
Normal file
174
backend/internal/repository/billing_cache.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
billingBalanceKeyPrefix = "billing:balance:"
|
||||||
|
billingSubKeyPrefix = "billing:sub:"
|
||||||
|
billingCacheTTL = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
subFieldStatus = "status"
|
||||||
|
subFieldExpiresAt = "expires_at"
|
||||||
|
subFieldDailyUsage = "daily_usage"
|
||||||
|
subFieldWeeklyUsage = "weekly_usage"
|
||||||
|
subFieldMonthlyUsage = "monthly_usage"
|
||||||
|
subFieldVersion = "version"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
deductBalanceScript = redis.NewScript(`
|
||||||
|
local current = redis.call('GET', KEYS[1])
|
||||||
|
if current == false then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||||
|
redis.call('SET', KEYS[1], newVal)
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
|
||||||
|
updateSubUsageScript = redis.NewScript(`
|
||||||
|
local exists = redis.call('EXISTS', KEYS[1])
|
||||||
|
if exists == 0 then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
local cost = tonumber(ARGV[1])
|
||||||
|
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||||
|
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||||
|
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
)
|
||||||
|
|
||||||
|
type billingCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
|
||||||
|
return &billingCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return strconv.ParseFloat(val, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
|
||||||
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
return c.parseSubscriptionCache(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
|
||||||
|
result := &ports.SubscriptionCacheData{}
|
||||||
|
|
||||||
|
result.Status = data[subFieldStatus]
|
||||||
|
if result.Status == "" {
|
||||||
|
return nil, errors.New("invalid cache: missing status")
|
||||||
|
}
|
||||||
|
|
||||||
|
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||||
|
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||||
|
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||||
|
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||||
|
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
if versionStr, ok := data[subFieldVersion]; ok {
|
||||||
|
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
|
||||||
|
if data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
|
||||||
|
fields := map[string]any{
|
||||||
|
subFieldStatus: data.Status,
|
||||||
|
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||||
|
subFieldDailyUsage: data.DailyUsage,
|
||||||
|
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||||
|
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||||
|
subFieldVersion: data.Version,
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
pipe.HSet(ctx, key, fields)
|
||||||
|
pipe.Expire(ctx, key, billingCacheTTL)
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||||
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
228
backend/internal/repository/claude_oauth_service.go
Normal file
228
backend/internal/repository/claude_oauth_service.go
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type claudeOAuthService struct{}
|
||||||
|
|
||||||
|
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
||||||
|
return &claudeOAuthService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||||
|
client := createReqClient(proxyURL)
|
||||||
|
|
||||||
|
var orgs []struct {
|
||||||
|
UUID string `json:"uuid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
targetURL := "https://claude.ai/api/organizations"
|
||||||
|
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetCookies(&http.Cookie{
|
||||||
|
Name: "sessionKey",
|
||||||
|
Value: sessionKey,
|
||||||
|
}).
|
||||||
|
SetSuccessResult(&orgs).
|
||||||
|
Get(targetURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||||
|
return "", fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(orgs) == 0 {
|
||||||
|
return "", fmt.Errorf("no organizations found")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||||
|
return orgs[0].UUID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||||
|
client := createReqClient(proxyURL)
|
||||||
|
|
||||||
|
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"response_type": "code",
|
||||||
|
"client_id": oauth.ClientID,
|
||||||
|
"organization_uuid": orgUUID,
|
||||||
|
"redirect_uri": oauth.RedirectURI,
|
||||||
|
"scope": scope,
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": codeChallenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||||
|
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||||
|
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetCookies(&http.Cookie{
|
||||||
|
Name: "sessionKey",
|
||||||
|
Value: sessionKey,
|
||||||
|
}).
|
||||||
|
SetHeader("Accept", "application/json").
|
||||||
|
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||||
|
SetHeader("Cache-Control", "no-cache").
|
||||||
|
SetHeader("Origin", "https://claude.ai").
|
||||||
|
SetHeader("Referer", "https://claude.ai/new").
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(reqBody).
|
||||||
|
SetSuccessResult(&result).
|
||||||
|
Post(authURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||||
|
return "", fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RedirectURI == "" {
|
||||||
|
return "", fmt.Errorf("no redirect_uri in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedURL, err := url.Parse(result.RedirectURI)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
queryParams := parsedURL.Query()
|
||||||
|
authCode := queryParams.Get("code")
|
||||||
|
responseState := queryParams.Get("state")
|
||||||
|
|
||||||
|
if authCode == "" {
|
||||||
|
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||||
|
}
|
||||||
|
|
||||||
|
fullCode := authCode
|
||||||
|
if responseState != "" {
|
||||||
|
fullCode = authCode + "#" + responseState
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||||
|
return fullCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||||
|
client := createReqClient(proxyURL)
|
||||||
|
|
||||||
|
// Parse code which may contain state in format "authCode#state"
|
||||||
|
authCode := code
|
||||||
|
codeState := ""
|
||||||
|
if idx := strings.Index(code, "#"); idx != -1 {
|
||||||
|
authCode = code[:idx]
|
||||||
|
codeState = code[idx+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"code": authCode,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": oauth.ClientID,
|
||||||
|
"redirect_uri": oauth.RedirectURI,
|
||||||
|
"code_verifier": codeVerifier,
|
||||||
|
}
|
||||||
|
|
||||||
|
if codeState != "" {
|
||||||
|
reqBody["state"] = codeState
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||||
|
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||||
|
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||||
|
|
||||||
|
var tokenResp oauth.TokenResponse
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(reqBody).
|
||||||
|
SetSuccessResult(&tokenResp).
|
||||||
|
Post(oauth.TokenURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||||
|
return &tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||||
|
client := createReqClient(proxyURL)
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "refresh_token")
|
||||||
|
formData.Set("refresh_token", refreshToken)
|
||||||
|
formData.Set("client_id", oauth.ClientID)
|
||||||
|
|
||||||
|
var tokenResp oauth.TokenResponse
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetFormDataFromValues(formData).
|
||||||
|
SetSuccessResult(&tokenResp).
|
||||||
|
Post(oauth.TokenURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createReqClient(proxyURL string) *req.Client {
|
||||||
|
client := req.C().
|
||||||
|
ImpersonateChrome().
|
||||||
|
SetTimeout(60 * time.Second)
|
||||||
|
|
||||||
|
if proxyURL != "" {
|
||||||
|
client.SetProxyURL(proxyURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return client
|
||||||
|
}
|
||||||
63
backend/internal/repository/claude_usage_service.go
Normal file
63
backend/internal/repository/claude_usage_service.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type claudeUsageService struct{}
|
||||||
|
|
||||||
|
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||||
|
return &claudeUsageService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||||
|
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to get default transport")
|
||||||
|
}
|
||||||
|
transport = transport.Clone()
|
||||||
|
if proxyURL != "" {
|
||||||
|
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||||
|
transport.Proxy = http.ProxyURL(parsedURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var usageResp service.ClaudeUsageResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &usageResp, nil
|
||||||
|
}
|
||||||
204
backend/internal/repository/concurrency_cache.go
Normal file
204
backend/internal/repository/concurrency_cache.go
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Key prefixes for independent slot keys
|
||||||
|
// Format: concurrency:account:{accountID}:{requestID}
|
||||||
|
accountSlotKeyPrefix = "concurrency:account:"
|
||||||
|
// Format: concurrency:user:{userID}:{requestID}
|
||||||
|
userSlotKeyPrefix = "concurrency:user:"
|
||||||
|
// Wait queue keeps counter format: concurrency:wait:{userID}
|
||||||
|
waitQueueKeyPrefix = "concurrency:wait:"
|
||||||
|
|
||||||
|
// Slot TTL - each slot expires independently
|
||||||
|
slotTTL = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
|
||||||
|
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
|
||||||
|
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
|
||||||
|
// ARGV[1] = maxConcurrency
|
||||||
|
// ARGV[2] = TTL in seconds
|
||||||
|
acquireScript = redis.NewScript(`
|
||||||
|
local pattern = KEYS[1]
|
||||||
|
local slotKey = KEYS[2]
|
||||||
|
local maxConcurrency = tonumber(ARGV[1])
|
||||||
|
local ttl = tonumber(ARGV[2])
|
||||||
|
|
||||||
|
-- Count existing slots using SCAN
|
||||||
|
local cursor = "0"
|
||||||
|
local count = 0
|
||||||
|
repeat
|
||||||
|
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||||
|
cursor = result[1]
|
||||||
|
count = count + #result[2]
|
||||||
|
until cursor == "0"
|
||||||
|
|
||||||
|
-- Check if we can acquire a slot
|
||||||
|
if count < maxConcurrency then
|
||||||
|
redis.call('SET', slotKey, '1', 'EX', ttl)
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
|
||||||
|
return 0
|
||||||
|
`)
|
||||||
|
|
||||||
|
// getCountScript counts slots using SCAN
|
||||||
|
// KEYS[1] = pattern for SCAN
|
||||||
|
getCountScript = redis.NewScript(`
|
||||||
|
local pattern = KEYS[1]
|
||||||
|
local cursor = "0"
|
||||||
|
local count = 0
|
||||||
|
repeat
|
||||||
|
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||||
|
cursor = result[1]
|
||||||
|
count = count + #result[2]
|
||||||
|
until cursor == "0"
|
||||||
|
return count
|
||||||
|
`)
|
||||||
|
|
||||||
|
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||||
|
// KEYS[1] = wait queue key
|
||||||
|
// ARGV[1] = maxWait
|
||||||
|
// ARGV[2] = TTL in seconds
|
||||||
|
incrementWaitScript = redis.NewScript(`
|
||||||
|
local current = redis.call('GET', KEYS[1])
|
||||||
|
if current == false then
|
||||||
|
current = 0
|
||||||
|
else
|
||||||
|
current = tonumber(current)
|
||||||
|
end
|
||||||
|
|
||||||
|
if current >= tonumber(ARGV[1]) then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
|
||||||
|
local newVal = redis.call('INCR', KEYS[1])
|
||||||
|
|
||||||
|
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||||
|
if newVal == 1 then
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
|
end
|
||||||
|
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
|
||||||
|
// decrementWaitScript - same as before
|
||||||
|
decrementWaitScript = redis.NewScript(`
|
||||||
|
local current = redis.call('GET', KEYS[1])
|
||||||
|
if current ~= false and tonumber(current) > 0 then
|
||||||
|
redis.call('DECR', KEYS[1])
|
||||||
|
end
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
)
|
||||||
|
|
||||||
|
type concurrencyCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
|
||||||
|
return &concurrencyCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for key generation
|
||||||
|
func accountSlotKey(accountID int64, requestID string) string {
|
||||||
|
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func accountSlotPattern(accountID int64) string {
|
||||||
|
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func userSlotKey(userID int64, requestID string) string {
|
||||||
|
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func userSlotPattern(userID int64) string {
|
||||||
|
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitQueueKey(userID int64) string {
|
||||||
|
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Account slot operations
|
||||||
|
|
||||||
|
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
pattern := accountSlotPattern(accountID)
|
||||||
|
slotKey := accountSlotKey(accountID, requestID)
|
||||||
|
|
||||||
|
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||||
|
slotKey := accountSlotKey(accountID, requestID)
|
||||||
|
return c.rdb.Del(ctx, slotKey).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
pattern := accountSlotPattern(accountID)
|
||||||
|
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// User slot operations
|
||||||
|
|
||||||
|
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
pattern := userSlotPattern(userID)
|
||||||
|
slotKey := userSlotKey(userID, requestID)
|
||||||
|
|
||||||
|
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||||
|
slotKey := userSlotKey(userID, requestID)
|
||||||
|
return c.rdb.Del(ctx, slotKey).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||||
|
pattern := userSlotPattern(userID)
|
||||||
|
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait queue operations
|
||||||
|
|
||||||
|
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
|
key := waitQueueKey(userID)
|
||||||
|
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||||
|
key := waitQueueKey(userID)
|
||||||
|
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||||
|
return err
|
||||||
|
}
|
||||||
48
backend/internal/repository/email_cache.go
Normal file
48
backend/internal/repository/email_cache.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const verifyCodeKeyPrefix = "verify_code:"
|
||||||
|
|
||||||
|
type emailCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
|
||||||
|
return &emailCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
|
||||||
|
key := verifyCodeKeyPrefix + email
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var data ports.VerificationCodeData
|
||||||
|
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
|
||||||
|
key := verifyCodeKeyPrefix + email
|
||||||
|
val, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||||
|
key := verifyCodeKeyPrefix + email
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
35
backend/internal/repository/gateway_cache.go
Normal file
35
backend/internal/repository/gateway_cache.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const stickySessionPrefix = "sticky_session:"
|
||||||
|
|
||||||
|
type gatewayCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
|
||||||
|
return &gatewayCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||||
|
key := stickySessionPrefix + sessionHash
|
||||||
|
return c.rdb.Get(ctx, key).Int64()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||||
|
key := stickySessionPrefix + sessionHash
|
||||||
|
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||||
|
key := stickySessionPrefix + sessionHash
|
||||||
|
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||||
|
}
|
||||||
116
backend/internal/repository/github_release_service.go
Normal file
116
backend/internal/repository/github_release_service.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type githubReleaseClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||||
|
return &githubReleaseClient{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||||
|
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||||
|
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var release service.GitHubRelease
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &release, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 10 * time.Minute}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SECURITY: Check Content-Length if available
|
||||||
|
if resp.ContentLength > maxSize {
|
||||||
|
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := os.Create(dest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = out.Close() }()
|
||||||
|
|
||||||
|
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||||
|
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||||
|
written, err := io.Copy(out, limited)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we hit the limit (downloaded more than maxSize)
|
||||||
|
if written > maxSize {
|
||||||
|
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||||
|
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
@@ -2,7 +2,8 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
|
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) {
|
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
||||||
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) {
|
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
|
||||||
var groups []model.Group
|
var groups []model.Group
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return groups, &PaginationResult{
|
return groups, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
|
|||||||
67
backend/internal/repository/http_upstream.go
Normal file
67
backend/internal/repository/http_upstream.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpUpstreamService is a generic HTTP upstream service that can be used for
|
||||||
|
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
|
||||||
|
type httpUpstreamService struct {
|
||||||
|
defaultClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPUpstream creates a new generic HTTP upstream service
|
||||||
|
func NewHTTPUpstream(cfg *config.Config) ports.HTTPUpstream {
|
||||||
|
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||||
|
if responseHeaderTimeout == 0 {
|
||||||
|
responseHeaderTimeout = 300 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &httpUpstreamService{
|
||||||
|
defaultClient: &http.Client{Transport: transport},
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||||
|
if proxyURL == "" {
|
||||||
|
return s.defaultClient.Do(req)
|
||||||
|
}
|
||||||
|
client := s.createProxyClient(proxyURL)
|
||||||
|
return client.Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||||
|
parsedURL, err := url.Parse(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return s.defaultClient
|
||||||
|
}
|
||||||
|
|
||||||
|
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||||
|
if responseHeaderTimeout == 0 {
|
||||||
|
responseHeaderTimeout = 300 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(parsedURL),
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Client{Transport: transport}
|
||||||
|
}
|
||||||
47
backend/internal/repository/identity_cache.go
Normal file
47
backend/internal/repository/identity_cache.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
fingerprintKeyPrefix = "fingerprint:"
|
||||||
|
fingerprintTTL = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
type identityCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
|
||||||
|
return &identityCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var fp ports.Fingerprint
|
||||||
|
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &fp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
|
||||||
|
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||||
|
val, err := json.Marshal(fp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||||
|
}
|
||||||
92
backend/internal/repository/openai_oauth_service.go
Normal file
92
backend/internal/repository/openai_oauth_service.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openaiOAuthService struct{}
|
||||||
|
|
||||||
|
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
||||||
|
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
|
||||||
|
return &openaiOAuthService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
client := createOpenAIReqClient(proxyURL)
|
||||||
|
|
||||||
|
if redirectURI == "" {
|
||||||
|
redirectURI = openai.DefaultRedirectURI
|
||||||
|
}
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "authorization_code")
|
||||||
|
formData.Set("client_id", openai.ClientID)
|
||||||
|
formData.Set("code", code)
|
||||||
|
formData.Set("redirect_uri", redirectURI)
|
||||||
|
formData.Set("code_verifier", codeVerifier)
|
||||||
|
|
||||||
|
var tokenResp openai.TokenResponse
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetFormDataFromValues(formData).
|
||||||
|
SetSuccessResult(&tokenResp).
|
||||||
|
Post(openai.TokenURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
client := createOpenAIReqClient(proxyURL)
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "refresh_token")
|
||||||
|
formData.Set("refresh_token", refreshToken)
|
||||||
|
formData.Set("client_id", openai.ClientID)
|
||||||
|
formData.Set("scope", openai.RefreshScopes)
|
||||||
|
|
||||||
|
var tokenResp openai.TokenResponse
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetFormDataFromValues(formData).
|
||||||
|
SetSuccessResult(&tokenResp).
|
||||||
|
Post(openai.TokenURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||||
|
client := req.C().
|
||||||
|
SetTimeout(60 * time.Second)
|
||||||
|
|
||||||
|
if proxyURL != "" {
|
||||||
|
client.SetProxyURL(proxyURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return client
|
||||||
|
}
|
||||||
73
backend/internal/repository/pricing_service.go
Normal file
73
backend/internal/repository/pricing_service.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pricingRemoteClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPricingRemoteClient() service.PricingRemoteClient {
|
||||||
|
return &pricingRemoteClient{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 哈希文件格式:hash filename 或者纯 hash
|
||||||
|
hash := strings.TrimSpace(string(body))
|
||||||
|
parts := strings.Fields(hash)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
return parts[0], nil
|
||||||
|
}
|
||||||
|
return hash, nil
|
||||||
|
}
|
||||||
104
backend/internal/repository/proxy_probe_service.go
Normal file
104
backend/internal/repository/proxy_probe_service.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type proxyProbeService struct{}
|
||||||
|
|
||||||
|
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||||
|
return &proxyProbeService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||||
|
transport, err := createProxyTransport(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
Timeout: 15 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
latencyMs := time.Since(startTime).Milliseconds()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipInfo struct {
|
||||||
|
IP string `json:"ip"`
|
||||||
|
City string `json:"city"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Country string `json:"country"`
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||||
|
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &service.ProxyExitInfo{
|
||||||
|
IP: ipInfo.IP,
|
||||||
|
City: ipInfo.City,
|
||||||
|
Region: ipInfo.Region,
|
||||||
|
Country: ipInfo.Country,
|
||||||
|
}, latencyMs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||||
|
parsedURL, err := url.Parse(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
switch parsedURL.Scheme {
|
||||||
|
case "http", "https":
|
||||||
|
transport.Proxy = http.ProxyURL(parsedURL)
|
||||||
|
case "socks5":
|
||||||
|
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||||
|
}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return dialer.Dial(network, addr)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
return transport, nil
|
||||||
|
}
|
||||||
@@ -2,7 +2,8 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
|
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) {
|
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", "")
|
return r.ListWithFilters(ctx, params, "", "", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||||
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) {
|
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||||
var proxies []model.Proxy
|
var proxies []model.Proxy
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return proxies, &PaginationResult{
|
return proxies, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
|
|||||||
49
backend/internal/repository/redeem_cache.go
Normal file
49
backend/internal/repository/redeem_cache.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
|
||||||
|
redeemLockKeyPrefix = "redeem:lock:"
|
||||||
|
redeemRateLimitDuration = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
type redeemCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
|
||||||
|
return &redeemCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||||
|
return c.rdb.Get(ctx, key).Int()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
|
||||||
|
key := redeemLockKeyPrefix + code
|
||||||
|
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
|
||||||
|
key := redeemLockKeyPrefix + code
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
@@ -2,7 +2,8 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
|
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) {
|
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", "")
|
return r.ListWithFilters(ctx, params, "", "", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
|
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
|
||||||
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) {
|
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||||
var codes []model.RedeemCode
|
var codes []model.RedeemCode
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return codes, &PaginationResult{
|
return codes, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -98,7 +99,7 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
||||||
Where("id = ? AND status = ?", id, model.StatusUnused).
|
Where("id = ? AND status = ?", id, model.StatusUnused).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"status": model.StatusUsed,
|
"status": model.StatusUsed,
|
||||||
"used_by": userID,
|
"used_by": userID,
|
||||||
"used_at": now,
|
"used_at": now,
|
||||||
|
|||||||
@@ -1,9 +1,5 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Repositories 所有仓库的集合
|
// Repositories 所有仓库的集合
|
||||||
type Repositories struct {
|
type Repositories struct {
|
||||||
User *UserRepository
|
User *UserRepository
|
||||||
@@ -16,59 +12,3 @@ type Repositories struct {
|
|||||||
Setting *SettingRepository
|
Setting *SettingRepository
|
||||||
UserSubscription *UserSubscriptionRepository
|
UserSubscription *UserSubscriptionRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRepositories 创建所有仓库
|
|
||||||
func NewRepositories(db *gorm.DB) *Repositories {
|
|
||||||
return &Repositories{
|
|
||||||
User: NewUserRepository(db),
|
|
||||||
ApiKey: NewApiKeyRepository(db),
|
|
||||||
Group: NewGroupRepository(db),
|
|
||||||
Account: NewAccountRepository(db),
|
|
||||||
Proxy: NewProxyRepository(db),
|
|
||||||
RedeemCode: NewRedeemCodeRepository(db),
|
|
||||||
UsageLog: NewUsageLogRepository(db),
|
|
||||||
Setting: NewSettingRepository(db),
|
|
||||||
UserSubscription: NewUserSubscriptionRepository(db),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PaginationParams 分页参数
|
|
||||||
type PaginationParams struct {
|
|
||||||
Page int
|
|
||||||
PageSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
// PaginationResult 分页结果
|
|
||||||
type PaginationResult struct {
|
|
||||||
Total int64
|
|
||||||
Page int
|
|
||||||
PageSize int
|
|
||||||
Pages int
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultPagination 默认分页参数
|
|
||||||
func DefaultPagination() PaginationParams {
|
|
||||||
return PaginationParams{
|
|
||||||
Page: 1,
|
|
||||||
PageSize: 20,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Offset 计算偏移量
|
|
||||||
func (p PaginationParams) Offset() int {
|
|
||||||
if p.Page < 1 {
|
|
||||||
p.Page = 1
|
|
||||||
}
|
|
||||||
return (p.Page - 1) * p.PageSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// Limit 获取限制数
|
|
||||||
func (p PaginationParams) Limit() int {
|
|
||||||
if p.PageSize < 1 {
|
|
||||||
return 20
|
|
||||||
}
|
|
||||||
if p.PageSize > 100 {
|
|
||||||
return 100
|
|
||||||
}
|
|
||||||
return p.PageSize
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|||||||
55
backend/internal/repository/turnstile_service.go
Normal file
55
backend/internal/repository/turnstile_service.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||||
|
|
||||||
|
type turnstileVerifier struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||||
|
return &turnstileVerifier{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("secret", secretKey)
|
||||||
|
formData.Set("response", token)
|
||||||
|
if remoteIP != "" {
|
||||||
|
formData.Set("remoteip", remoteIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, err := v.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
var result service.TurnstileVerifyResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
28
backend/internal/repository/update_cache.go
Normal file
28
backend/internal/repository/update_cache.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const updateCacheKey = "update:latest"
|
||||||
|
|
||||||
|
type updateCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
|
||||||
|
return &updateCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
|
||||||
|
return c.rdb.Get(ctx, updateCacheKey).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
|
||||||
|
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
|
||||||
|
}
|
||||||
@@ -2,8 +2,10 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -17,6 +19,30 @@ func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository {
|
|||||||
return &UsageLogRepository{db: db}
|
return &UsageLogRepository{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
|
||||||
|
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
|
||||||
|
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
|
||||||
|
var perfStats struct {
|
||||||
|
RequestCount int64 `gorm:"column:request_count"`
|
||||||
|
TokenCount int64 `gorm:"column:token_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||||
|
Select(`
|
||||||
|
COUNT(*) as request_count,
|
||||||
|
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
|
||||||
|
`).
|
||||||
|
Where("created_at >= ?", fiveMinutesAgo)
|
||||||
|
|
||||||
|
if userID > 0 {
|
||||||
|
db = db.Where("user_id = ?", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Scan(&perfStats)
|
||||||
|
// 返回5分钟平均值
|
||||||
|
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
|
||||||
|
}
|
||||||
|
|
||||||
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
|
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
|
||||||
return r.db.WithContext(ctx).Create(log).Error
|
return r.db.WithContext(ctx).Create(log).Error
|
||||||
}
|
}
|
||||||
@@ -30,7 +56,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag
|
|||||||
return &log, nil
|
return &log, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||||
var logs []model.UsageLog
|
var logs []model.UsageLog
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -49,7 +75,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return logs, &PaginationResult{
|
return logs, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -57,7 +83,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||||
var logs []model.UsageLog
|
var logs []model.UsageLog
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -76,7 +102,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return logs, &PaginationResult{
|
return logs, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -111,46 +137,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DashboardStats 仪表盘统计
|
// DashboardStats 仪表盘统计
|
||||||
type DashboardStats struct {
|
type DashboardStats = usagestats.DashboardStats
|
||||||
// 用户统计
|
|
||||||
TotalUsers int64 `json:"total_users"`
|
|
||||||
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
|
||||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
|
||||||
|
|
||||||
// API Key 统计
|
|
||||||
TotalApiKeys int64 `json:"total_api_keys"`
|
|
||||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
|
||||||
|
|
||||||
// 账户统计
|
|
||||||
TotalAccounts int64 `json:"total_accounts"`
|
|
||||||
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
|
|
||||||
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
|
|
||||||
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
|
|
||||||
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
|
|
||||||
|
|
||||||
// 累计 Token 使用统计
|
|
||||||
TotalRequests int64 `json:"total_requests"`
|
|
||||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
|
||||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
|
||||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
|
||||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
|
||||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
|
||||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
|
||||||
|
|
||||||
// 今日 Token 使用统计
|
|
||||||
TodayRequests int64 `json:"today_requests"`
|
|
||||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
|
||||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
|
||||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
|
||||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
|
||||||
TodayTokens int64 `json:"today_tokens"`
|
|
||||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
|
||||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
|
||||||
|
|
||||||
// 系统运行统计
|
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||||
var stats DashboardStats
|
var stats DashboardStats
|
||||||
@@ -267,10 +254,13 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
stats.TodayCost = todayStats.TodayCost
|
stats.TodayCost = todayStats.TodayCost
|
||||||
stats.TodayActualCost = todayStats.TodayActualCost
|
stats.TodayActualCost = todayStats.TodayActualCost
|
||||||
|
|
||||||
|
// 性能指标:RPM 和 TPM(最近1分钟,全局)
|
||||||
|
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, 0)
|
||||||
|
|
||||||
return &stats, nil
|
return &stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||||
var logs []model.UsageLog
|
var logs []model.UsageLog
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -289,7 +279,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return logs, &PaginationResult{
|
return logs, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -297,7 +287,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||||
var logs []model.UsageLog
|
var logs []model.UsageLog
|
||||||
err := r.db.WithContext(ctx).
|
err := r.db.WithContext(ctx).
|
||||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
|
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
|
||||||
@@ -306,7 +296,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
|
|||||||
return logs, nil, err
|
return logs, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||||
var logs []model.UsageLog
|
var logs []model.UsageLog
|
||||||
err := r.db.WithContext(ctx).
|
err := r.db.WithContext(ctx).
|
||||||
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
|
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
|
||||||
@@ -315,7 +305,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
|
|||||||
return logs, nil, err
|
return logs, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||||
var logs []model.UsageLog
|
var logs []model.UsageLog
|
||||||
err := r.db.WithContext(ctx).
|
err := r.db.WithContext(ctx).
|
||||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||||
@@ -324,7 +314,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
|
|||||||
return logs, nil, err
|
return logs, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||||
var logs []model.UsageLog
|
var logs []model.UsageLog
|
||||||
err := r.db.WithContext(ctx).
|
err := r.db.WithContext(ctx).
|
||||||
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
|
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
|
||||||
@@ -337,15 +327,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
|
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountStats 账号使用统计
|
|
||||||
type AccountStats struct {
|
|
||||||
Requests int64 `json:"requests"`
|
|
||||||
Tokens int64 `json:"tokens"`
|
|
||||||
Cost float64 `json:"cost"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountTodayStats 获取账号今日统计
|
// GetAccountTodayStats 获取账号今日统计
|
||||||
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*AccountStats, error) {
|
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||||
today := timezone.Today()
|
today := timezone.Today()
|
||||||
|
|
||||||
var stats struct {
|
var stats struct {
|
||||||
@@ -367,7 +350,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AccountStats{
|
return &usagestats.AccountStats{
|
||||||
Requests: stats.Requests,
|
Requests: stats.Requests,
|
||||||
Tokens: stats.Tokens,
|
Tokens: stats.Tokens,
|
||||||
Cost: stats.Cost,
|
Cost: stats.Cost,
|
||||||
@@ -375,7 +358,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountWindowStats 获取账号时间窗口内的统计
|
// GetAccountWindowStats 获取账号时间窗口内的统计
|
||||||
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*AccountStats, error) {
|
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||||
var stats struct {
|
var stats struct {
|
||||||
Requests int64 `gorm:"column:requests"`
|
Requests int64 `gorm:"column:requests"`
|
||||||
Tokens int64 `gorm:"column:tokens"`
|
Tokens int64 `gorm:"column:tokens"`
|
||||||
@@ -395,7 +378,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AccountStats{
|
return &usagestats.AccountStats{
|
||||||
Requests: stats.Requests,
|
Requests: stats.Requests,
|
||||||
Tokens: stats.Tokens,
|
Tokens: stats.Tokens,
|
||||||
Cost: stats.Cost,
|
Cost: stats.Cost,
|
||||||
@@ -403,109 +386,16 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrendDataPoint represents a single point in trend data
|
// TrendDataPoint represents a single point in trend data
|
||||||
type TrendDataPoint struct {
|
type TrendDataPoint = usagestats.TrendDataPoint
|
||||||
Date string `json:"date"`
|
|
||||||
Requests int64 `json:"requests"`
|
|
||||||
InputTokens int64 `json:"input_tokens"`
|
|
||||||
OutputTokens int64 `json:"output_tokens"`
|
|
||||||
CacheTokens int64 `json:"cache_tokens"`
|
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModelStat represents usage statistics for a single model
|
// ModelStat represents usage statistics for a single model
|
||||||
type ModelStat struct {
|
type ModelStat = usagestats.ModelStat
|
||||||
Model string `json:"model"`
|
|
||||||
Requests int64 `json:"requests"`
|
|
||||||
InputTokens int64 `json:"input_tokens"`
|
|
||||||
OutputTokens int64 `json:"output_tokens"`
|
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserUsageTrendPoint represents user usage trend data point
|
// UserUsageTrendPoint represents user usage trend data point
|
||||||
type UserUsageTrendPoint struct {
|
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||||||
Date string `json:"date"`
|
|
||||||
UserID int64 `json:"user_id"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
Requests int64 `json:"requests"`
|
|
||||||
Tokens int64 `json:"tokens"`
|
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUsageTrend returns usage trend data grouped by date
|
|
||||||
// granularity: "day" or "hour"
|
|
||||||
func (r *UsageLogRepository) GetUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
|
|
||||||
var results []TrendDataPoint
|
|
||||||
|
|
||||||
// Choose date format based on granularity
|
|
||||||
var dateFormat string
|
|
||||||
if granularity == "hour" {
|
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
|
||||||
} else {
|
|
||||||
dateFormat = "YYYY-MM-DD"
|
|
||||||
}
|
|
||||||
|
|
||||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
|
||||||
Select(`
|
|
||||||
TO_CHAR(created_at, ?) as date,
|
|
||||||
COUNT(*) as requests,
|
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
|
||||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
|
||||||
`, dateFormat).
|
|
||||||
Where("created_at >= ? AND created_at < ?", startTime, endTime).
|
|
||||||
Group("date").
|
|
||||||
Order("date ASC").
|
|
||||||
Scan(&results).Error
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModelStats returns usage statistics grouped by model
|
|
||||||
func (r *UsageLogRepository) GetModelStats(ctx context.Context, startTime, endTime time.Time) ([]ModelStat, error) {
|
|
||||||
var results []ModelStat
|
|
||||||
|
|
||||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
|
||||||
Select(`
|
|
||||||
model,
|
|
||||||
COUNT(*) as requests,
|
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
|
||||||
`).
|
|
||||||
Where("created_at >= ? AND created_at < ?", startTime, endTime).
|
|
||||||
Group("model").
|
|
||||||
Order("total_tokens DESC").
|
|
||||||
Scan(&results).Error
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||||
type ApiKeyUsageTrendPoint struct {
|
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||||
Date string `json:"date"`
|
|
||||||
ApiKeyID int64 `json:"api_key_id"`
|
|
||||||
KeyName string `json:"key_name"`
|
|
||||||
Requests int64 `json:"requests"`
|
|
||||||
Tokens int64 `json:"tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||||
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
|
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
|
||||||
@@ -598,34 +488,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UserDashboardStats 用户仪表盘统计
|
// UserDashboardStats 用户仪表盘统计
|
||||||
type UserDashboardStats struct {
|
type UserDashboardStats = usagestats.UserDashboardStats
|
||||||
// API Key 统计
|
|
||||||
TotalApiKeys int64 `json:"total_api_keys"`
|
|
||||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
|
||||||
|
|
||||||
// 累计 Token 使用统计
|
|
||||||
TotalRequests int64 `json:"total_requests"`
|
|
||||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
|
||||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
|
||||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
|
||||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
|
||||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
|
||||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
|
||||||
|
|
||||||
// 今日 Token 使用统计
|
|
||||||
TodayRequests int64 `json:"today_requests"`
|
|
||||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
|
||||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
|
||||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
|
||||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
|
||||||
TodayTokens int64 `json:"today_tokens"`
|
|
||||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
|
||||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
|
||||||
|
|
||||||
// 性能统计
|
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDashboardStats 获取用户专属的仪表盘统计
|
// GetUserDashboardStats 获取用户专属的仪表盘统计
|
||||||
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
|
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
|
||||||
@@ -708,6 +571,9 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
|||||||
stats.TodayCost = todayStats.TodayCost
|
stats.TodayCost = todayStats.TodayCost
|
||||||
stats.TodayActualCost = todayStats.TodayActualCost
|
stats.TodayActualCost = todayStats.TodayActualCost
|
||||||
|
|
||||||
|
// 性能指标:RPM 和 TPM(最近1分钟,仅统计该用户的请求)
|
||||||
|
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, userID)
|
||||||
|
|
||||||
return &stats, nil
|
return &stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -772,15 +638,10 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UsageLogFilters represents filters for usage log queries
|
// UsageLogFilters represents filters for usage log queries
|
||||||
type UsageLogFilters struct {
|
type UsageLogFilters = usagestats.UsageLogFilters
|
||||||
UserID int64
|
|
||||||
ApiKeyID int64
|
|
||||||
StartTime *time.Time
|
|
||||||
EndTime *time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListWithFilters lists usage logs with optional filters (for admin)
|
// ListWithFilters lists usage logs with optional filters (for admin)
|
||||||
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *PaginationResult, error) {
|
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||||
var logs []model.UsageLog
|
var logs []model.UsageLog
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -816,7 +677,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return logs, &PaginationResult{
|
return logs, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -825,23 +686,10 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UsageStats represents usage statistics
|
// UsageStats represents usage statistics
|
||||||
type UsageStats struct {
|
type UsageStats = usagestats.UsageStats
|
||||||
TotalRequests int64 `json:"total_requests"`
|
|
||||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
|
||||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
|
||||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
|
||||||
TotalCost float64 `json:"total_cost"`
|
|
||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchUserUsageStats represents usage stats for a single user
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
type BatchUserUsageStats struct {
|
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||||
UserID int64 `json:"user_id"`
|
|
||||||
TodayActualCost float64 `json:"today_actual_cost"`
|
|
||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
||||||
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||||||
@@ -901,11 +749,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||||
type BatchApiKeyUsageStats struct {
|
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
|
||||||
ApiKeyID int64 `json:"api_key_id"`
|
|
||||||
TodayActualCost float64 `json:"today_actual_cost"`
|
|
||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
|
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||||
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
||||||
@@ -964,6 +808,79 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
||||||
|
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
|
||||||
|
var results []TrendDataPoint
|
||||||
|
|
||||||
|
var dateFormat string
|
||||||
|
if granularity == "hour" {
|
||||||
|
dateFormat = "YYYY-MM-DD HH24:00"
|
||||||
|
} else {
|
||||||
|
dateFormat = "YYYY-MM-DD"
|
||||||
|
}
|
||||||
|
|
||||||
|
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||||
|
Select(`
|
||||||
|
TO_CHAR(created_at, ?) as date,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||||
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||||
|
`, dateFormat).
|
||||||
|
Where("created_at >= ? AND created_at < ?", startTime, endTime)
|
||||||
|
|
||||||
|
if userID > 0 {
|
||||||
|
db = db.Where("user_id = ?", userID)
|
||||||
|
}
|
||||||
|
if apiKeyID > 0 {
|
||||||
|
db = db.Where("api_key_id = ?", apiKeyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := db.Group("date").Order("date ASC").Scan(&results).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||||
|
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
|
||||||
|
var results []ModelStat
|
||||||
|
|
||||||
|
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||||
|
Select(`
|
||||||
|
model,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||||
|
`).
|
||||||
|
Where("created_at >= ? AND created_at < ?", startTime, endTime)
|
||||||
|
|
||||||
|
if userID > 0 {
|
||||||
|
db = db.Where("user_id = ?", userID)
|
||||||
|
}
|
||||||
|
if apiKeyID > 0 {
|
||||||
|
db = db.Where("api_key_id = ?", apiKeyID)
|
||||||
|
}
|
||||||
|
if accountID > 0 {
|
||||||
|
db = db.Where("account_id = ?", accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := db.Group("model").Order("total_tokens DESC").Scan(&results).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetGlobalStats gets usage statistics for all users within a time range
|
// GetGlobalStats gets usage statistics for all users within a time range
|
||||||
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||||||
var stats struct {
|
var stats struct {
|
||||||
@@ -1004,3 +921,169 @@ func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
|||||||
AverageDurationMs: stats.AverageDurationMs,
|
AverageDurationMs: stats.AverageDurationMs,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountUsageHistory represents daily usage history for an account
|
||||||
|
type AccountUsageHistory = usagestats.AccountUsageHistory
|
||||||
|
|
||||||
|
// AccountUsageSummary represents summary statistics for an account
|
||||||
|
type AccountUsageSummary = usagestats.AccountUsageSummary
|
||||||
|
|
||||||
|
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||||
|
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||||
|
|
||||||
|
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||||
|
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
|
||||||
|
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||||
|
if daysCount <= 0 {
|
||||||
|
daysCount = 30
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get daily history
|
||||||
|
var historyResults []struct {
|
||||||
|
Date string `gorm:"column:date"`
|
||||||
|
Requests int64 `gorm:"column:requests"`
|
||||||
|
Tokens int64 `gorm:"column:tokens"`
|
||||||
|
Cost float64 `gorm:"column:cost"`
|
||||||
|
ActualCost float64 `gorm:"column:actual_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||||
|
Select(`
|
||||||
|
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||||
|
`).
|
||||||
|
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||||
|
Group("date").
|
||||||
|
Order("date ASC").
|
||||||
|
Scan(&historyResults).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build history with labels
|
||||||
|
history := make([]AccountUsageHistory, 0, len(historyResults))
|
||||||
|
for _, h := range historyResults {
|
||||||
|
// Parse date to get label (MM/DD)
|
||||||
|
t, _ := time.Parse("2006-01-02", h.Date)
|
||||||
|
label := t.Format("01/02")
|
||||||
|
history = append(history, AccountUsageHistory{
|
||||||
|
Date: h.Date,
|
||||||
|
Label: label,
|
||||||
|
Requests: h.Requests,
|
||||||
|
Tokens: h.Tokens,
|
||||||
|
Cost: h.Cost,
|
||||||
|
ActualCost: h.ActualCost,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate summary
|
||||||
|
var totalActualCost, totalStandardCost float64
|
||||||
|
var totalRequests, totalTokens int64
|
||||||
|
var highestCostDay, highestRequestDay *AccountUsageHistory
|
||||||
|
|
||||||
|
for i := range history {
|
||||||
|
h := &history[i]
|
||||||
|
totalActualCost += h.ActualCost
|
||||||
|
totalStandardCost += h.Cost
|
||||||
|
totalRequests += h.Requests
|
||||||
|
totalTokens += h.Tokens
|
||||||
|
|
||||||
|
if highestCostDay == nil || h.ActualCost > highestCostDay.ActualCost {
|
||||||
|
highestCostDay = h
|
||||||
|
}
|
||||||
|
if highestRequestDay == nil || h.Requests > highestRequestDay.Requests {
|
||||||
|
highestRequestDay = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
actualDaysUsed := len(history)
|
||||||
|
if actualDaysUsed == 0 {
|
||||||
|
actualDaysUsed = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get average duration
|
||||||
|
var avgDuration struct {
|
||||||
|
AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||||
|
}
|
||||||
|
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||||
|
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
|
||||||
|
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||||
|
Scan(&avgDuration)
|
||||||
|
|
||||||
|
summary := AccountUsageSummary{
|
||||||
|
Days: daysCount,
|
||||||
|
ActualDaysUsed: actualDaysUsed,
|
||||||
|
TotalCost: totalActualCost,
|
||||||
|
TotalStandardCost: totalStandardCost,
|
||||||
|
TotalRequests: totalRequests,
|
||||||
|
TotalTokens: totalTokens,
|
||||||
|
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
|
||||||
|
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
||||||
|
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
||||||
|
AvgDurationMs: avgDuration.AvgDurationMs,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set today's stats
|
||||||
|
todayStr := timezone.Now().Format("2006-01-02")
|
||||||
|
for i := range history {
|
||||||
|
if history[i].Date == todayStr {
|
||||||
|
summary.Today = &struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
}{
|
||||||
|
Date: history[i].Date,
|
||||||
|
Cost: history[i].ActualCost,
|
||||||
|
Requests: history[i].Requests,
|
||||||
|
Tokens: history[i].Tokens,
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set highest cost day
|
||||||
|
if highestCostDay != nil {
|
||||||
|
summary.HighestCostDay = &struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
}{
|
||||||
|
Date: highestCostDay.Date,
|
||||||
|
Label: highestCostDay.Label,
|
||||||
|
Cost: highestCostDay.ActualCost,
|
||||||
|
Requests: highestCostDay.Requests,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set highest request day
|
||||||
|
if highestRequestDay != nil {
|
||||||
|
summary.HighestRequestDay = &struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
}{
|
||||||
|
Date: highestRequestDay.Date,
|
||||||
|
Label: highestRequestDay.Label,
|
||||||
|
Requests: highestRequestDay.Requests,
|
||||||
|
Cost: highestRequestDay.ActualCost,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get model statistics using the unified method
|
||||||
|
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
|
||||||
|
if err != nil {
|
||||||
|
models = []ModelStat{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AccountUsageStatsResponse{
|
||||||
|
History: history,
|
||||||
|
Summary: summary,
|
||||||
|
Models: models,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
|
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) {
|
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", "")
|
return r.ListWithFilters(ctx, params, "", "", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListWithFilters lists users with optional filtering by status, role, and search query
|
// ListWithFilters lists users with optional filtering by status, role, and search query
|
||||||
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) {
|
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
|
||||||
var users []model.User
|
var users []model.User
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -65,23 +66,53 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP
|
|||||||
}
|
}
|
||||||
if search != "" {
|
if search != "" {
|
||||||
searchPattern := "%" + search + "%"
|
searchPattern := "%" + search + "%"
|
||||||
db = db.Where("email ILIKE ?", searchPattern)
|
db = db.Where(
|
||||||
|
"email ILIKE ? OR username ILIKE ? OR wechat ILIKE ?",
|
||||||
|
searchPattern, searchPattern, searchPattern,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.Count(&total).Error; err != nil {
|
if err := db.Count(&total).Error; err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Query users with pagination (reuse the same db with filters applied)
|
||||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
|
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Batch load subscriptions for all users (avoid N+1)
|
||||||
|
if len(users) > 0 {
|
||||||
|
userIDs := make([]int64, len(users))
|
||||||
|
userMap := make(map[int64]*model.User, len(users))
|
||||||
|
for i := range users {
|
||||||
|
userIDs[i] = users[i].ID
|
||||||
|
userMap[users[i].ID] = &users[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query active subscriptions with groups in one query
|
||||||
|
var subscriptions []model.UserSubscription
|
||||||
|
if err := r.db.WithContext(ctx).
|
||||||
|
Preload("Group").
|
||||||
|
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive).
|
||||||
|
Find(&subscriptions).Error; err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate subscriptions with users
|
||||||
|
for i := range subscriptions {
|
||||||
|
if user, ok := userMap[subscriptions[i].UserID]; ok {
|
||||||
|
user.Subscriptions = append(user.Subscriptions, subscriptions[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pages := int(total) / params.Limit()
|
pages := int(total) / params.Limit()
|
||||||
if int(total)%params.Limit() > 0 {
|
if int(total)%params.Limit() > 0 {
|
||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, &PaginationResult{
|
return users, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -128,3 +159,15 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
|
|||||||
return result.RowsAffected, result.Error
|
return result.RowsAffected, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
|
||||||
|
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
|
||||||
|
var user model.User
|
||||||
|
err := r.db.WithContext(ctx).
|
||||||
|
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
|
||||||
|
Order("id ASC").
|
||||||
|
First(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListByGroupID 获取分组的所有订阅(分页)
|
// ListByGroupID 获取分组的所有订阅(分页)
|
||||||
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) {
|
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
var subs []model.UserSubscription
|
var subs []model.UserSubscription
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return subs, &PaginationResult{
|
return subs, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List 获取所有订阅(分页,支持筛选)
|
// List 获取所有订阅(分页,支持筛选)
|
||||||
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) {
|
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
var subs []model.UserSubscription
|
var subs []model.UserSubscription
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
@@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
|
|||||||
pages++
|
pages++
|
||||||
}
|
}
|
||||||
|
|
||||||
return subs, &PaginationResult{
|
return subs, &pagination.PaginationResult{
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: params.Page,
|
Page: params.Page,
|
||||||
PageSize: params.Limit(),
|
PageSize: params.Limit(),
|
||||||
@@ -184,7 +185,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
|
|||||||
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
||||||
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
||||||
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
||||||
@@ -196,7 +197,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
|||||||
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"daily_usage_usd": 0,
|
"daily_usage_usd": 0,
|
||||||
"daily_window_start": newWindowStart,
|
"daily_window_start": newWindowStart,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
@@ -207,7 +208,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
|
|||||||
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"weekly_usage_usd": 0,
|
"weekly_usage_usd": 0,
|
||||||
"weekly_window_start": newWindowStart,
|
"weekly_window_start": newWindowStart,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
@@ -218,7 +219,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
|
|||||||
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"monthly_usage_usd": 0,
|
"monthly_usage_usd": 0,
|
||||||
"monthly_window_start": newWindowStart,
|
"monthly_window_start": newWindowStart,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
@@ -229,7 +230,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
|
|||||||
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"daily_window_start": activateTime,
|
"daily_window_start": activateTime,
|
||||||
"weekly_window_start": activateTime,
|
"weekly_window_start": activateTime,
|
||||||
"monthly_window_start": activateTime,
|
"monthly_window_start": activateTime,
|
||||||
@@ -241,7 +242,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
|
|||||||
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"status": status,
|
"status": status,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
@@ -251,7 +252,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
|
|||||||
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"expires_at": newExpiresAt,
|
"expires_at": newExpiresAt,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
@@ -261,7 +262,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
|
|||||||
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
||||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"notes": notes,
|
"notes": notes,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
}).Error
|
}).Error
|
||||||
@@ -280,7 +281,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
|
|||||||
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||||
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]any{
|
||||||
"status": model.SubscriptionStatusExpired,
|
"status": model.SubscriptionStatusExpired,
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
})
|
})
|
||||||
|
|||||||
52
backend/internal/repository/wire.go
Normal file
52
backend/internal/repository/wire.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
|
"github.com/google/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderSet is the Wire provider set for all repositories
|
||||||
|
var ProviderSet = wire.NewSet(
|
||||||
|
NewUserRepository,
|
||||||
|
NewApiKeyRepository,
|
||||||
|
NewGroupRepository,
|
||||||
|
NewAccountRepository,
|
||||||
|
NewProxyRepository,
|
||||||
|
NewRedeemCodeRepository,
|
||||||
|
NewUsageLogRepository,
|
||||||
|
NewSettingRepository,
|
||||||
|
NewUserSubscriptionRepository,
|
||||||
|
wire.Struct(new(Repositories), "*"),
|
||||||
|
|
||||||
|
// Cache implementations
|
||||||
|
NewGatewayCache,
|
||||||
|
NewBillingCache,
|
||||||
|
NewApiKeyCache,
|
||||||
|
NewConcurrencyCache,
|
||||||
|
NewEmailCache,
|
||||||
|
NewIdentityCache,
|
||||||
|
NewRedeemCache,
|
||||||
|
NewUpdateCache,
|
||||||
|
|
||||||
|
// HTTP service ports (DI Strategy A: return interface directly)
|
||||||
|
NewTurnstileVerifier,
|
||||||
|
NewPricingRemoteClient,
|
||||||
|
NewGitHubReleaseClient,
|
||||||
|
NewProxyExitInfoProber,
|
||||||
|
NewClaudeUsageFetcher,
|
||||||
|
NewClaudeOAuthClient,
|
||||||
|
NewHTTPUpstream,
|
||||||
|
NewOpenAIOAuthClient,
|
||||||
|
|
||||||
|
// Bind concrete repositories to service port interfaces
|
||||||
|
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||||
|
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
||||||
|
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
|
||||||
|
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
|
||||||
|
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
|
||||||
|
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
|
||||||
|
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
|
||||||
|
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
|
||||||
|
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
|
||||||
|
)
|
||||||
45
backend/internal/server/http.go
Normal file
45
backend/internal/server/http.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderSet 提供服务器层的依赖
|
||||||
|
var ProviderSet = wire.NewSet(
|
||||||
|
ProvideRouter,
|
||||||
|
ProvideHTTPServer,
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProvideRouter 提供路由器
|
||||||
|
func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||||
|
if cfg.Server.Mode == "release" {
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(gin.Recovery())
|
||||||
|
|
||||||
|
return SetupRouter(r, cfg, handlers, services, repos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideHTTPServer 提供 HTTP 服务器
|
||||||
|
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
||||||
|
return &http.Server{
|
||||||
|
Addr: cfg.Server.Address(),
|
||||||
|
Handler: router,
|
||||||
|
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
|
||||||
|
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
|
||||||
|
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
|
||||||
|
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
|
||||||
|
// 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟
|
||||||
|
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
||||||
|
}
|
||||||
|
}
|
||||||
312
backend/internal/server/router.go
Normal file
312
backend/internal/server/router.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupRouter 配置路由器中间件和路由
|
||||||
|
func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||||
|
// 应用中间件
|
||||||
|
r.Use(middleware.Logger())
|
||||||
|
r.Use(middleware.CORS())
|
||||||
|
|
||||||
|
// 注册路由
|
||||||
|
registerRoutes(r, handlers, services, repos)
|
||||||
|
|
||||||
|
// Serve embedded frontend if available
|
||||||
|
if web.HasEmbeddedFrontend() {
|
||||||
|
r.Use(web.ServeEmbeddedFrontend())
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerRoutes 注册所有 HTTP 路由
|
||||||
|
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
|
||||||
|
// 健康检查
|
||||||
|
r.GET("/health", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Claude Code 遥测日志(忽略,直接返回200)
|
||||||
|
r.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup status endpoint (always returns needs_setup: false in normal mode)
|
||||||
|
// This is used by the frontend to detect when the service has restarted after setup
|
||||||
|
r.GET("/setup/status", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": 0,
|
||||||
|
"data": gin.H{
|
||||||
|
"needs_setup": false,
|
||||||
|
"step": "completed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// API v1
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
{
|
||||||
|
// 公开接口
|
||||||
|
auth := v1.Group("/auth")
|
||||||
|
{
|
||||||
|
auth.POST("/register", h.Auth.Register)
|
||||||
|
auth.POST("/login", h.Auth.Login)
|
||||||
|
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 公开设置(无需认证)
|
||||||
|
settings := v1.Group("/settings")
|
||||||
|
{
|
||||||
|
settings.GET("/public", h.Setting.GetPublicSettings)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 需要认证的接口
|
||||||
|
authenticated := v1.Group("")
|
||||||
|
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
|
||||||
|
{
|
||||||
|
// 当前用户信息
|
||||||
|
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||||
|
|
||||||
|
// 用户接口
|
||||||
|
user := authenticated.Group("/user")
|
||||||
|
{
|
||||||
|
user.GET("/profile", h.User.GetProfile)
|
||||||
|
user.PUT("/password", h.User.ChangePassword)
|
||||||
|
user.PUT("", h.User.UpdateProfile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// API Key管理
|
||||||
|
keys := authenticated.Group("/keys")
|
||||||
|
{
|
||||||
|
keys.GET("", h.APIKey.List)
|
||||||
|
keys.GET("/:id", h.APIKey.GetByID)
|
||||||
|
keys.POST("", h.APIKey.Create)
|
||||||
|
keys.PUT("/:id", h.APIKey.Update)
|
||||||
|
keys.DELETE("/:id", h.APIKey.Delete)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户可用分组(非管理员接口)
|
||||||
|
groups := authenticated.Group("/groups")
|
||||||
|
{
|
||||||
|
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用记录
|
||||||
|
usage := authenticated.Group("/usage")
|
||||||
|
{
|
||||||
|
usage.GET("", h.Usage.List)
|
||||||
|
usage.GET("/:id", h.Usage.GetByID)
|
||||||
|
usage.GET("/stats", h.Usage.Stats)
|
||||||
|
// User dashboard endpoints
|
||||||
|
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||||
|
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||||
|
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||||
|
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 卡密兑换
|
||||||
|
redeem := authenticated.Group("/redeem")
|
||||||
|
{
|
||||||
|
redeem.POST("", h.Redeem.Redeem)
|
||||||
|
redeem.GET("/history", h.Redeem.GetHistory)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户订阅
|
||||||
|
subscriptions := authenticated.Group("/subscriptions")
|
||||||
|
{
|
||||||
|
subscriptions.GET("", h.Subscription.List)
|
||||||
|
subscriptions.GET("/active", h.Subscription.GetActive)
|
||||||
|
subscriptions.GET("/progress", h.Subscription.GetProgress)
|
||||||
|
subscriptions.GET("/summary", h.Subscription.GetSummary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 管理员接口
|
||||||
|
admin := v1.Group("/admin")
|
||||||
|
admin.Use(middleware.AdminAuth(s.Auth, repos.User, s.Setting))
|
||||||
|
{
|
||||||
|
// 仪表盘
|
||||||
|
dashboard := admin.Group("/dashboard")
|
||||||
|
{
|
||||||
|
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||||
|
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||||
|
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||||
|
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
||||||
|
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
||||||
|
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||||
|
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||||
|
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户管理
|
||||||
|
users := admin.Group("/users")
|
||||||
|
{
|
||||||
|
users.GET("", h.Admin.User.List)
|
||||||
|
users.GET("/:id", h.Admin.User.GetByID)
|
||||||
|
users.POST("", h.Admin.User.Create)
|
||||||
|
users.PUT("/:id", h.Admin.User.Update)
|
||||||
|
users.DELETE("/:id", h.Admin.User.Delete)
|
||||||
|
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||||
|
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||||
|
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分组管理
|
||||||
|
groups := admin.Group("/groups")
|
||||||
|
{
|
||||||
|
groups.GET("", h.Admin.Group.List)
|
||||||
|
groups.GET("/all", h.Admin.Group.GetAll)
|
||||||
|
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||||
|
groups.POST("", h.Admin.Group.Create)
|
||||||
|
groups.PUT("/:id", h.Admin.Group.Update)
|
||||||
|
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||||
|
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||||
|
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 账号管理
|
||||||
|
accounts := admin.Group("/accounts")
|
||||||
|
{
|
||||||
|
accounts.GET("", h.Admin.Account.List)
|
||||||
|
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||||
|
accounts.POST("", h.Admin.Account.Create)
|
||||||
|
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||||
|
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||||
|
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||||
|
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||||
|
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||||
|
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||||
|
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||||
|
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||||
|
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||||
|
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||||
|
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||||
|
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||||
|
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||||
|
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||||
|
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||||
|
|
||||||
|
// Claude OAuth routes
|
||||||
|
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||||
|
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
||||||
|
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
||||||
|
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
|
||||||
|
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
|
||||||
|
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI OAuth routes
|
||||||
|
openai := admin.Group("/openai")
|
||||||
|
{
|
||||||
|
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||||
|
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||||
|
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||||
|
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||||
|
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 代理管理
|
||||||
|
proxies := admin.Group("/proxies")
|
||||||
|
{
|
||||||
|
proxies.GET("", h.Admin.Proxy.List)
|
||||||
|
proxies.GET("/all", h.Admin.Proxy.GetAll)
|
||||||
|
proxies.GET("/:id", h.Admin.Proxy.GetByID)
|
||||||
|
proxies.POST("", h.Admin.Proxy.Create)
|
||||||
|
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||||
|
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||||
|
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||||
|
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||||
|
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||||
|
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 卡密管理
|
||||||
|
codes := admin.Group("/redeem-codes")
|
||||||
|
{
|
||||||
|
codes.GET("", h.Admin.Redeem.List)
|
||||||
|
codes.GET("/stats", h.Admin.Redeem.GetStats)
|
||||||
|
codes.GET("/export", h.Admin.Redeem.Export)
|
||||||
|
codes.GET("/:id", h.Admin.Redeem.GetByID)
|
||||||
|
codes.POST("/generate", h.Admin.Redeem.Generate)
|
||||||
|
codes.DELETE("/:id", h.Admin.Redeem.Delete)
|
||||||
|
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
|
||||||
|
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 系统设置
|
||||||
|
adminSettings := admin.Group("/settings")
|
||||||
|
{
|
||||||
|
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
||||||
|
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||||
|
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||||
|
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||||
|
// Admin API Key 管理
|
||||||
|
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
|
||||||
|
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
|
||||||
|
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 系统管理
|
||||||
|
system := admin.Group("/system")
|
||||||
|
{
|
||||||
|
system.GET("/version", h.Admin.System.GetVersion)
|
||||||
|
system.GET("/check-updates", h.Admin.System.CheckUpdates)
|
||||||
|
system.POST("/update", h.Admin.System.PerformUpdate)
|
||||||
|
system.POST("/rollback", h.Admin.System.Rollback)
|
||||||
|
system.POST("/restart", h.Admin.System.RestartService)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 订阅管理
|
||||||
|
subscriptions := admin.Group("/subscriptions")
|
||||||
|
{
|
||||||
|
subscriptions.GET("", h.Admin.Subscription.List)
|
||||||
|
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
|
||||||
|
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
|
||||||
|
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||||
|
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||||
|
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||||
|
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分组下的订阅列表
|
||||||
|
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
|
||||||
|
|
||||||
|
// 用户下的订阅列表
|
||||||
|
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
|
||||||
|
|
||||||
|
// 使用记录管理
|
||||||
|
usage := admin.Group("/usage")
|
||||||
|
{
|
||||||
|
usage.GET("", h.Admin.Usage.List)
|
||||||
|
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||||
|
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||||
|
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// API网关(Claude API兼容)
|
||||||
|
gateway := r.Group("/v1")
|
||||||
|
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
|
||||||
|
{
|
||||||
|
gateway.POST("/messages", h.Gateway.Messages)
|
||||||
|
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||||
|
gateway.GET("/models", h.Gateway.Models)
|
||||||
|
gateway.GET("/usage", h.Gateway.Usage)
|
||||||
|
// OpenAI Responses API
|
||||||
|
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI Responses API(不带v1前缀的别名)
|
||||||
|
r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses)
|
||||||
|
}
|
||||||
@@ -4,8 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -16,37 +17,37 @@ var (
|
|||||||
|
|
||||||
// CreateAccountRequest 创建账号请求
|
// CreateAccountRequest 创建账号请求
|
||||||
type CreateAccountRequest struct {
|
type CreateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Credentials map[string]interface{} `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]interface{} `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
Priority int `json:"priority"`
|
Priority int `json:"priority"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountRequest 更新账号请求
|
// UpdateAccountRequest 更新账号请求
|
||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name *string `json:"name"`
|
Name *string `json:"name"`
|
||||||
Credentials *map[string]interface{} `json:"credentials"`
|
Credentials *map[string]any `json:"credentials"`
|
||||||
Extra *map[string]interface{} `json:"extra"`
|
Extra *map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
Priority *int `json:"priority"`
|
Priority *int `json:"priority"`
|
||||||
Status *string `json:"status"`
|
Status *string `json:"status"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountService 账号管理服务
|
// AccountService 账号管理服务
|
||||||
type AccountService struct {
|
type AccountService struct {
|
||||||
accountRepo *repository.AccountRepository
|
accountRepo ports.AccountRepository
|
||||||
groupRepo *repository.GroupRepository
|
groupRepo ports.GroupRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountService 创建账号服务实例
|
// NewAccountService 创建账号服务实例
|
||||||
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
|
func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
|
||||||
return &AccountService{
|
return &AccountService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List 获取账号列表
|
// List 获取账号列表
|
||||||
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
|
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||||
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
||||||
|
|||||||
@@ -10,20 +10,23 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||||
testModel = "claude-sonnet-4-5-20250929"
|
testOpenAIAPIURL = "https://api.openai.com/v1/responses"
|
||||||
|
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestEvent represents a SSE event for account testing
|
// TestEvent represents a SSE event for account testing
|
||||||
@@ -37,39 +40,46 @@ type TestEvent struct {
|
|||||||
|
|
||||||
// AccountTestService handles account testing operations
|
// AccountTestService handles account testing operations
|
||||||
type AccountTestService struct {
|
type AccountTestService struct {
|
||||||
repos *repository.Repositories
|
accountRepo ports.AccountRepository
|
||||||
oauthService *OAuthService
|
oauthService *OAuthService
|
||||||
httpClient *http.Client
|
openaiOAuthService *OpenAIOAuthService
|
||||||
|
httpUpstream ports.HTTPUpstream
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountTestService creates a new AccountTestService
|
// NewAccountTestService creates a new AccountTestService
|
||||||
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
|
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream ports.HTTPUpstream) *AccountTestService {
|
||||||
return &AccountTestService{
|
return &AccountTestService{
|
||||||
repos: repos,
|
accountRepo: accountRepo,
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
httpClient: &http.Client{
|
openaiOAuthService: openaiOAuthService,
|
||||||
Timeout: 60 * time.Second,
|
httpUpstream: httpUpstream,
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateSessionString generates a Claude Code style session string
|
// generateSessionString generates a Claude Code style session string
|
||||||
func generateSessionString() string {
|
func generateSessionString() (string, error) {
|
||||||
bytes := make([]byte, 32)
|
bytes := make([]byte, 32)
|
||||||
rand.Read(bytes)
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
hex64 := hex.EncodeToString(bytes)
|
hex64 := hex.EncodeToString(bytes)
|
||||||
sessionUUID := uuid.New().String()
|
sessionUUID := uuid.New().String()
|
||||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
|
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts
|
// createTestPayload creates a Claude Code style test request payload
|
||||||
func createTestPayload() map[string]interface{} {
|
func createTestPayload(modelID string) (map[string]any, error) {
|
||||||
return map[string]interface{}{
|
sessionID, err := generateSessionString()
|
||||||
"model": testModel,
|
if err != nil {
|
||||||
"messages": []map[string]interface{}{
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"model": modelID,
|
||||||
|
"messages": []map[string]any{
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": []map[string]interface{}{
|
"content": []map[string]any{
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "hi",
|
"text": "hi",
|
||||||
@@ -80,7 +90,7 @@ func createTestPayload() map[string]interface{} {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"system": []map[string]interface{}{
|
"system": []map[string]any{
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||||
@@ -90,47 +100,62 @@ func createTestPayload() map[string]interface{} {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"metadata": map[string]string{
|
"metadata": map[string]string{
|
||||||
"user_id": generateSessionString(),
|
"user_id": sessionID,
|
||||||
},
|
},
|
||||||
"max_tokens": 1024,
|
"max_tokens": 1024,
|
||||||
"temperature": 1,
|
"temperature": 1,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
}
|
}, nil
|
||||||
}
|
|
||||||
|
|
||||||
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
|
|
||||||
func createApiKeyTestPayload(model string) map[string]interface{} {
|
|
||||||
return map[string]interface{}{
|
|
||||||
"model": model,
|
|
||||||
"messages": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "hi",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"max_tokens": 1024,
|
|
||||||
"stream": true,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAccountConnection tests an account's connection by sending a test request
|
// TestAccountConnection tests an account's connection by sending a test request
|
||||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error {
|
// All account types use full Claude Code client characteristics, only auth header differs
|
||||||
|
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
||||||
|
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// Get account
|
// Get account
|
||||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, "Account not found")
|
return s.sendErrorAndEnd(c, "Account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine authentication method based on account type
|
// Route to platform-specific test method
|
||||||
|
if account.IsOpenAI() {
|
||||||
|
return s.testOpenAIAccountConnection(c, account, modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.testClaudeAccountConnection(c, account, modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// testClaudeAccountConnection tests an Anthropic Claude account's connection
|
||||||
|
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
// Determine the model to use
|
||||||
|
testModelID := modelID
|
||||||
|
if testModelID == "" {
|
||||||
|
testModelID = claude.DefaultTestModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// For API Key accounts with model mapping, map the model
|
||||||
|
if account.Type == "apikey" {
|
||||||
|
mapping := account.GetModelMapping()
|
||||||
|
if len(mapping) > 0 {
|
||||||
|
if mappedModel, exists := mapping[testModelID]; exists {
|
||||||
|
testModelID = mappedModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine authentication method and API URL
|
||||||
var authToken string
|
var authToken string
|
||||||
var authType string // "bearer" for OAuth, "apikey" for API Key
|
var useBearer bool
|
||||||
var apiURL string
|
var apiURL string
|
||||||
|
|
||||||
if account.IsOAuth() {
|
if account.IsOAuth() {
|
||||||
// OAuth or Setup Token account
|
// OAuth or Setup Token - use Bearer token
|
||||||
authType = "bearer"
|
useBearer = true
|
||||||
apiURL = testClaudeAPIURL
|
apiURL = testClaudeAPIURL
|
||||||
authToken = account.GetCredential("access_token")
|
authToken = account.GetCredential("access_token")
|
||||||
if authToken == "" {
|
if authToken == "" {
|
||||||
@@ -141,7 +166,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
needRefresh := false
|
needRefresh := false
|
||||||
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
|
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
|
||||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||||
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer
|
if err == nil && time.Now().Unix()+300 > expiresAt {
|
||||||
needRefresh = true
|
needRefresh = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -154,19 +179,17 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
authToken = tokenInfo.AccessToken
|
authToken = tokenInfo.AccessToken
|
||||||
}
|
}
|
||||||
} else if account.Type == "apikey" {
|
} else if account.Type == "apikey" {
|
||||||
// API Key account
|
// API Key - use x-api-key header
|
||||||
authType = "apikey"
|
useBearer = false
|
||||||
authToken = account.GetCredential("api_key")
|
authToken = account.GetCredential("api_key")
|
||||||
if authToken == "" {
|
if authToken == "" {
|
||||||
return s.sendErrorAndEnd(c, "No API key available")
|
return s.sendErrorAndEnd(c, "No API key available")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get base URL (use default if not set)
|
|
||||||
apiURL = account.GetBaseURL()
|
apiURL = account.GetBaseURL()
|
||||||
if apiURL == "" {
|
if apiURL == "" {
|
||||||
apiURL = "https://api.anthropic.com"
|
apiURL = "https://api.anthropic.com"
|
||||||
}
|
}
|
||||||
// Append /v1/messages endpoint
|
|
||||||
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
|
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
|
||||||
} else {
|
} else {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||||
@@ -179,61 +202,49 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
|
|
||||||
// Create test request payload
|
// Create Claude Code style payload (same for all account types)
|
||||||
var payload map[string]interface{}
|
payload, err := createTestPayload(testModelID)
|
||||||
var actualModel string
|
if err != nil {
|
||||||
if authType == "apikey" {
|
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||||
// Use simpler payload for API Key (without Claude Code specific fields)
|
|
||||||
// Apply model mapping if configured
|
|
||||||
actualModel = account.GetMappedModel(testModel)
|
|
||||||
payload = createApiKeyTestPayload(actualModel)
|
|
||||||
} else {
|
|
||||||
actualModel = testModel
|
|
||||||
payload = createTestPayload()
|
|
||||||
}
|
}
|
||||||
payloadBytes, _ := json.Marshal(payload)
|
payloadBytes, _ := json.Marshal(payload)
|
||||||
|
|
||||||
// Send test_start event with model info
|
// Send test_start event
|
||||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel})
|
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set headers based on auth type
|
// Set common headers
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
|
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
|
||||||
|
|
||||||
if authType == "bearer" {
|
// Apply Claude Code client headers
|
||||||
|
for key, value := range claude.DefaultHeaders {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set authentication header
|
||||||
|
if useBearer {
|
||||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
|
|
||||||
} else {
|
} else {
|
||||||
// API Key uses x-api-key header
|
|
||||||
req.Header.Set("x-api-key", authToken)
|
req.Header.Set("x-api-key", authToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure proxy if account has one
|
// Get proxy URL
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
proxyURL := ""
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL := account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
if proxyURL != "" {
|
|
||||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
|
||||||
transport.Proxy = http.ProxyURL(parsedURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{
|
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||||
Transport: transport,
|
|
||||||
Timeout: 60 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -241,18 +252,161 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process SSE stream
|
// Process SSE stream
|
||||||
return s.processStream(c, resp.Body)
|
return s.processClaudeStream(c, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processStream processes the SSE stream from Claude API
|
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||||
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error {
|
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
// Default to openai.DefaultTestModel for OpenAI testing
|
||||||
|
testModelID := modelID
|
||||||
|
if testModelID == "" {
|
||||||
|
testModelID = openai.DefaultTestModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// For API Key accounts with model mapping, map the model
|
||||||
|
if account.Type == "apikey" {
|
||||||
|
mapping := account.GetModelMapping()
|
||||||
|
if len(mapping) > 0 {
|
||||||
|
if mappedModel, exists := mapping[testModelID]; exists {
|
||||||
|
testModelID = mappedModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine authentication method and API URL
|
||||||
|
var authToken string
|
||||||
|
var apiURL string
|
||||||
|
var isOAuth bool
|
||||||
|
var chatgptAccountID string
|
||||||
|
|
||||||
|
if account.IsOAuth() {
|
||||||
|
isOAuth = true
|
||||||
|
// OAuth - use Bearer token with ChatGPT internal API
|
||||||
|
authToken = account.GetOpenAIAccessToken()
|
||||||
|
if authToken == "" {
|
||||||
|
return s.sendErrorAndEnd(c, "No access token available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if token is expired and refresh if needed
|
||||||
|
if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
|
||||||
|
tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
||||||
|
}
|
||||||
|
authToken = tokenInfo.AccessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth uses ChatGPT internal API
|
||||||
|
apiURL = chatgptCodexAPIURL
|
||||||
|
chatgptAccountID = account.GetChatGPTAccountID()
|
||||||
|
} else if account.Type == "apikey" {
|
||||||
|
// API Key - use Platform API
|
||||||
|
authToken = account.GetOpenAIApiKey()
|
||||||
|
if authToken == "" {
|
||||||
|
return s.sendErrorAndEnd(c, "No API key available")
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := account.GetOpenAIBaseURL()
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.openai.com"
|
||||||
|
}
|
||||||
|
apiURL = strings.TrimSuffix(baseURL, "/") + "/v1/responses"
|
||||||
|
} else {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set SSE headers
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
// Create OpenAI Responses API payload
|
||||||
|
payload := createOpenAITestPayload(testModelID, isOAuth)
|
||||||
|
payloadBytes, _ := json.Marshal(payload)
|
||||||
|
|
||||||
|
// Send test_start event
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set common headers
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
|
|
||||||
|
// Set OAuth-specific headers for ChatGPT internal API
|
||||||
|
if isOAuth {
|
||||||
|
req.Host = "chatgpt.com"
|
||||||
|
req.Header.Set("accept", "text/event-stream")
|
||||||
|
if chatgptAccountID != "" {
|
||||||
|
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get proxy URL
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process SSE stream
|
||||||
|
return s.processOpenAIStream(c, resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createOpenAITestPayload creates a test payload for OpenAI Responses API
|
||||||
|
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
|
||||||
|
payload := map[string]any{
|
||||||
|
"model": modelID,
|
||||||
|
"input": []map[string]any{
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": []map[string]any{
|
||||||
|
{
|
||||||
|
"type": "input_text",
|
||||||
|
"text": "hi",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"stream": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth accounts using ChatGPT internal API require store: false
|
||||||
|
if isOAuth {
|
||||||
|
payload["store"] = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// All accounts require instructions for Responses API
|
||||||
|
payload["instructions"] = openai.DefaultInstructions
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// processClaudeStream processes the SSE stream from Claude API
|
||||||
|
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
|
||||||
reader := bufio.NewReader(body)
|
reader := bufio.NewReader(body)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadString('\n')
|
line, err := reader.ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// Stream ended, send complete event
|
|
||||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -270,7 +424,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var data map[string]interface{}
|
var data map[string]any
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -279,7 +433,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
|
|
||||||
switch eventType {
|
switch eventType {
|
||||||
case "content_block_delta":
|
case "content_block_delta":
|
||||||
if delta, ok := data["delta"].(map[string]interface{}); ok {
|
if delta, ok := data["delta"].(map[string]any); ok {
|
||||||
if text, ok := delta["text"].(string); ok {
|
if text, ok := delta["text"].(string); ok {
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||||
}
|
}
|
||||||
@@ -289,7 +443,60 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
return nil
|
return nil
|
||||||
case "error":
|
case "error":
|
||||||
errorMsg := "Unknown error"
|
errorMsg := "Unknown error"
|
||||||
if errData, ok := data["error"].(map[string]interface{}); ok {
|
if errData, ok := data["error"].(map[string]any); ok {
|
||||||
|
if msg, ok := errData["message"].(string); ok {
|
||||||
|
errorMsg = msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.sendErrorAndEnd(c, errorMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processOpenAIStream processes the SSE stream from OpenAI Responses API
|
||||||
|
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
|
||||||
|
reader := bufio.NewReader(body)
|
||||||
|
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||||
|
if jsonStr == "[DONE]" {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var data map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
eventType, _ := data["type"].(string)
|
||||||
|
|
||||||
|
switch eventType {
|
||||||
|
case "response.output_text.delta":
|
||||||
|
// OpenAI Responses API uses "delta" field for text content
|
||||||
|
if delta, ok := data["delta"].(string); ok && delta != "" {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
|
||||||
|
}
|
||||||
|
case "response.completed":
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||||
|
return nil
|
||||||
|
case "error":
|
||||||
|
errorMsg := "Unknown error"
|
||||||
|
if errData, ok := data["error"].(map[string]any); ok {
|
||||||
if msg, ok := errData["message"].(string); ok {
|
if msg, ok := errData["message"].(string); ok {
|
||||||
errorMsg = msg
|
errorMsg = msg
|
||||||
}
|
}
|
||||||
@@ -302,7 +509,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
// sendEvent sends a SSE event to the client
|
// sendEvent sends a SSE event to the client
|
||||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||||
eventJSON, _ := json.Marshal(event)
|
eventJSON, _ := json.Marshal(event)
|
||||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||||
|
log.Printf("failed to write SSE event: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,5 +520,5 @@ func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
|||||||
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
|
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
|
||||||
log.Printf("Account test error: %s", errorMsg)
|
log.Printf("Account test error: %s", errorMsg)
|
||||||
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
||||||
return fmt.Errorf(errorMsg)
|
return fmt.Errorf("%s", errorMsg)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,17 +2,14 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
)
|
)
|
||||||
|
|
||||||
// usageCache 用于缓存usage数据
|
// usageCache 用于缓存usage数据
|
||||||
@@ -35,10 +32,10 @@ type WindowStats struct {
|
|||||||
|
|
||||||
// UsageProgress 使用量进度
|
// UsageProgress 使用量进度
|
||||||
type UsageProgress struct {
|
type UsageProgress struct {
|
||||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageInfo 账号使用量信息
|
// UsageInfo 账号使用量信息
|
||||||
@@ -65,21 +62,24 @@ type ClaudeUsageResponse struct {
|
|||||||
} `json:"seven_day_sonnet"`
|
} `json:"seven_day_sonnet"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||||
|
type ClaudeUsageFetcher interface {
|
||||||
|
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
// AccountUsageService 账号使用量查询服务
|
// AccountUsageService 账号使用量查询服务
|
||||||
type AccountUsageService struct {
|
type AccountUsageService struct {
|
||||||
repos *repository.Repositories
|
accountRepo ports.AccountRepository
|
||||||
oauthService *OAuthService
|
usageLogRepo ports.UsageLogRepository
|
||||||
httpClient *http.Client
|
usageFetcher ClaudeUsageFetcher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountUsageService 创建AccountUsageService实例
|
// NewAccountUsageService 创建AccountUsageService实例
|
||||||
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
|
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||||
return &AccountUsageService{
|
return &AccountUsageService{
|
||||||
repos: repos,
|
accountRepo: accountRepo,
|
||||||
oauthService: oauthService,
|
usageLogRepo: usageLogRepo,
|
||||||
httpClient: &http.Client{
|
usageFetcher: usageFetcher,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +88,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS
|
|||||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||||
// API Key账号: 不支持usage查询
|
// API Key账号: 不支持usage查询
|
||||||
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get account failed: %w", err)
|
return nil, fmt.Errorf("get account failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -97,8 +97,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
if account.CanGetUsage() {
|
if account.CanGetUsage() {
|
||||||
// 检查缓存
|
// 检查缓存
|
||||||
if cached, ok := usageCacheMap.Load(accountID); ok {
|
if cached, ok := usageCacheMap.Load(accountID); ok {
|
||||||
cache := cached.(*usageCache)
|
cache, ok := cached.(*usageCache)
|
||||||
if time.Since(cache.timestamp) < cacheTTL {
|
if !ok {
|
||||||
|
usageCacheMap.Delete(accountID)
|
||||||
|
} else if time.Since(cache.timestamp) < cacheTTL {
|
||||||
return cache.data, nil
|
return cache.data, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -148,7 +150,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
|
|||||||
startTime = time.Now().Add(-5 * time.Hour)
|
startTime = time.Now().Add(-5 * time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
|
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
||||||
return
|
return
|
||||||
@@ -163,7 +165,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
|
|||||||
|
|
||||||
// GetTodayStats 获取账号今日统计
|
// GetTodayStats 获取账号今日统计
|
||||||
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
|
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
|
||||||
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
|
stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get today stats failed: %w", err)
|
return nil, fmt.Errorf("get today stats failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -175,60 +177,33 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||||
|
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get account usage stats failed: %w", err)
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
||||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
||||||
// 获取access token(从credentials中获取)
|
|
||||||
accessToken := account.GetCredential("access_token")
|
accessToken := account.GetCredential("access_token")
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("no access token available")
|
return nil, fmt.Errorf("no access token available")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取代理配置
|
var proxyURL string
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL := account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
if proxyURL != "" {
|
|
||||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
|
||||||
transport.Proxy = http.ProxyURL(parsedURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{
|
usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||||
Transport: transport,
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建请求
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create request failed: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析响应
|
|
||||||
var usageResp ClaudeUsageResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
|
||||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 转换为UsageInfo
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return s.buildUsageInfo(&usageResp, &now), nil
|
return s.buildUsageInfo(usageResp, &now), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseTime 尝试多种格式解析时间
|
// parseTime 尝试多种格式解析时间
|
||||||
|
|||||||
@@ -2,20 +2,15 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"log"
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,9 +22,9 @@ type AdminService interface {
|
|||||||
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
|
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
|
||||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
|
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
|
||||||
DeleteUser(ctx context.Context, id int64) error
|
DeleteUser(ctx context.Context, id int64) error
|
||||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
|
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
|
||||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error)
|
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||||
|
|
||||||
// Group management
|
// Group management
|
||||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
|
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
|
||||||
@@ -50,6 +45,7 @@ type AdminService interface {
|
|||||||
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
|
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
|
||||||
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
|
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
|
||||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
|
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
|
||||||
|
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||||
|
|
||||||
// Proxy management
|
// Proxy management
|
||||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
|
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
|
||||||
@@ -76,6 +72,9 @@ type AdminService interface {
|
|||||||
type CreateUserInput struct {
|
type CreateUserInput struct {
|
||||||
Email string
|
Email string
|
||||||
Password string
|
Password string
|
||||||
|
Username string
|
||||||
|
Wechat string
|
||||||
|
Notes string
|
||||||
Balance float64
|
Balance float64
|
||||||
Concurrency int
|
Concurrency int
|
||||||
AllowedGroups []int64
|
AllowedGroups []int64
|
||||||
@@ -84,6 +83,9 @@ type CreateUserInput struct {
|
|||||||
type UpdateUserInput struct {
|
type UpdateUserInput struct {
|
||||||
Email string
|
Email string
|
||||||
Password string
|
Password string
|
||||||
|
Username *string
|
||||||
|
Wechat *string
|
||||||
|
Notes *string
|
||||||
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
||||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||||
Status string
|
Status string
|
||||||
@@ -119,8 +121,8 @@ type CreateAccountInput struct {
|
|||||||
Name string
|
Name string
|
||||||
Platform string
|
Platform string
|
||||||
Type string
|
Type string
|
||||||
Credentials map[string]interface{}
|
Credentials map[string]any
|
||||||
Extra map[string]interface{}
|
Extra map[string]any
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency int
|
Concurrency int
|
||||||
Priority int
|
Priority int
|
||||||
@@ -130,8 +132,8 @@ type CreateAccountInput struct {
|
|||||||
type UpdateAccountInput struct {
|
type UpdateAccountInput struct {
|
||||||
Name string
|
Name string
|
||||||
Type string // Account type: oauth, setup-token, apikey
|
Type string // Account type: oauth, setup-token, apikey
|
||||||
Credentials map[string]interface{}
|
Credentials map[string]any
|
||||||
Extra map[string]interface{}
|
Extra map[string]any
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||||
@@ -139,6 +141,33 @@ type UpdateAccountInput struct {
|
|||||||
GroupIDs *[]int64
|
GroupIDs *[]int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||||
|
type BulkUpdateAccountsInput struct {
|
||||||
|
AccountIDs []int64
|
||||||
|
Name string
|
||||||
|
ProxyID *int64
|
||||||
|
Concurrency *int
|
||||||
|
Priority *int
|
||||||
|
Status string
|
||||||
|
GroupIDs *[]int64
|
||||||
|
Credentials map[string]any
|
||||||
|
Extra map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// BulkUpdateAccountResult captures the result for a single account update.
|
||||||
|
type BulkUpdateAccountResult struct {
|
||||||
|
AccountID int64 `json:"account_id"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||||
|
type BulkUpdateAccountsResult struct {
|
||||||
|
Success int `json:"success"`
|
||||||
|
Failed int `json:"failed"`
|
||||||
|
Results []BulkUpdateAccountResult `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
type CreateProxyInput struct {
|
type CreateProxyInput struct {
|
||||||
Name string
|
Name string
|
||||||
Protocol string
|
Protocol string
|
||||||
@@ -177,44 +206,57 @@ type ProxyTestResult struct {
|
|||||||
Country string `json:"country,omitempty"`
|
Country string `json:"country,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProxyExitInfo represents proxy exit information from ipinfo.io
|
||||||
|
type ProxyExitInfo struct {
|
||||||
|
IP string
|
||||||
|
City string
|
||||||
|
Region string
|
||||||
|
Country string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
|
||||||
|
type ProxyExitInfoProber interface {
|
||||||
|
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
// adminServiceImpl implements AdminService
|
// adminServiceImpl implements AdminService
|
||||||
type adminServiceImpl struct {
|
type adminServiceImpl struct {
|
||||||
userRepo *repository.UserRepository
|
userRepo ports.UserRepository
|
||||||
groupRepo *repository.GroupRepository
|
groupRepo ports.GroupRepository
|
||||||
accountRepo *repository.AccountRepository
|
accountRepo ports.AccountRepository
|
||||||
proxyRepo *repository.ProxyRepository
|
proxyRepo ports.ProxyRepository
|
||||||
apiKeyRepo *repository.ApiKeyRepository
|
apiKeyRepo ports.ApiKeyRepository
|
||||||
redeemCodeRepo *repository.RedeemCodeRepository
|
redeemCodeRepo ports.RedeemCodeRepository
|
||||||
usageLogRepo *repository.UsageLogRepository
|
|
||||||
userSubRepo *repository.UserSubscriptionRepository
|
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
|
proxyProber ProxyExitInfoProber
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminService creates a new AdminService
|
// NewAdminService creates a new AdminService
|
||||||
func NewAdminService(repos *repository.Repositories) AdminService {
|
func NewAdminService(
|
||||||
|
userRepo ports.UserRepository,
|
||||||
|
groupRepo ports.GroupRepository,
|
||||||
|
accountRepo ports.AccountRepository,
|
||||||
|
proxyRepo ports.ProxyRepository,
|
||||||
|
apiKeyRepo ports.ApiKeyRepository,
|
||||||
|
redeemCodeRepo ports.RedeemCodeRepository,
|
||||||
|
billingCacheService *BillingCacheService,
|
||||||
|
proxyProber ProxyExitInfoProber,
|
||||||
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: repos.User,
|
userRepo: userRepo,
|
||||||
groupRepo: repos.Group,
|
groupRepo: groupRepo,
|
||||||
accountRepo: repos.Account,
|
accountRepo: accountRepo,
|
||||||
proxyRepo: repos.Proxy,
|
proxyRepo: proxyRepo,
|
||||||
apiKeyRepo: repos.ApiKey,
|
apiKeyRepo: apiKeyRepo,
|
||||||
redeemCodeRepo: repos.RedeemCode,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
usageLogRepo: repos.UsageLog,
|
billingCacheService: billingCacheService,
|
||||||
userSubRepo: repos.UserSubscription,
|
proxyProber: proxyProber,
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
|
|
||||||
// 注意:AdminService是接口,需要类型断言
|
|
||||||
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
|
|
||||||
if impl, ok := adminService.(*adminServiceImpl); ok {
|
|
||||||
impl.billingCacheService = billingCacheService
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// User management implementations
|
// User management implementations
|
||||||
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
|
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
|
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@@ -229,6 +271,9 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User,
|
|||||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
|
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
|
||||||
user := &model.User{
|
user := &model.User{
|
||||||
Email: input.Email,
|
Email: input.Email,
|
||||||
|
Username: input.Username,
|
||||||
|
Wechat: input.Wechat,
|
||||||
|
Notes: input.Notes,
|
||||||
Role: "user", // Always create as regular user, never admin
|
Role: "user", // Always create as regular user, never admin
|
||||||
Balance: input.Balance,
|
Balance: input.Balance,
|
||||||
Concurrency: input.Concurrency,
|
Concurrency: input.Concurrency,
|
||||||
@@ -254,8 +299,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
return nil, errors.New("cannot disable admin user")
|
return nil, errors.New("cannot disable admin user")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track balance and concurrency changes for logging
|
|
||||||
oldBalance := user.Balance
|
|
||||||
oldConcurrency := user.Concurrency
|
oldConcurrency := user.Concurrency
|
||||||
|
|
||||||
if input.Email != "" {
|
if input.Email != "" {
|
||||||
@@ -266,22 +309,25 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Role is not allowed to be changed via API to prevent privilege escalation
|
|
||||||
|
if input.Username != nil {
|
||||||
|
user.Username = *input.Username
|
||||||
|
}
|
||||||
|
if input.Wechat != nil {
|
||||||
|
user.Wechat = *input.Wechat
|
||||||
|
}
|
||||||
|
if input.Notes != nil {
|
||||||
|
user.Notes = *input.Notes
|
||||||
|
}
|
||||||
|
|
||||||
if input.Status != "" {
|
if input.Status != "" {
|
||||||
user.Status = input.Status
|
user.Status = input.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只在指针非 nil 时更新 Balance(支持设置为 0)
|
|
||||||
if input.Balance != nil {
|
|
||||||
user.Balance = *input.Balance
|
|
||||||
}
|
|
||||||
|
|
||||||
// 只在指针非 nil 时更新 Concurrency(支持设置为任意值)
|
|
||||||
if input.Concurrency != nil {
|
if input.Concurrency != nil {
|
||||||
user.Concurrency = *input.Concurrency
|
user.Concurrency = *input.Concurrency
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只在指针非 nil 时更新 AllowedGroups
|
|
||||||
if input.AllowedGroups != nil {
|
if input.AllowedGroups != nil {
|
||||||
user.AllowedGroups = *input.AllowedGroups
|
user.AllowedGroups = *input.AllowedGroups
|
||||||
}
|
}
|
||||||
@@ -290,39 +336,15 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 余额变化时失效缓存
|
|
||||||
if input.Balance != nil && *input.Balance != oldBalance {
|
|
||||||
if s.billingCacheService != nil {
|
|
||||||
go func() {
|
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create adjustment records for balance/concurrency changes
|
|
||||||
balanceDiff := user.Balance - oldBalance
|
|
||||||
if balanceDiff != 0 {
|
|
||||||
adjustmentRecord := &model.RedeemCode{
|
|
||||||
Code: model.GenerateRedeemCode(),
|
|
||||||
Type: model.AdjustmentTypeAdminBalance,
|
|
||||||
Value: balanceDiff,
|
|
||||||
Status: model.StatusUsed,
|
|
||||||
UsedBy: &user.ID,
|
|
||||||
}
|
|
||||||
now := time.Now()
|
|
||||||
adjustmentRecord.UsedAt = &now
|
|
||||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
|
||||||
// Log error but don't fail the update
|
|
||||||
// The user update has already succeeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||||
if concurrencyDiff != 0 {
|
if concurrencyDiff != 0 {
|
||||||
|
code, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
adjustmentRecord := &model.RedeemCode{
|
adjustmentRecord := &model.RedeemCode{
|
||||||
Code: model.GenerateRedeemCode(),
|
Code: code,
|
||||||
Type: model.AdjustmentTypeAdminConcurrency,
|
Type: model.AdjustmentTypeAdminConcurrency,
|
||||||
Value: float64(concurrencyDiff),
|
Value: float64(concurrencyDiff),
|
||||||
Status: model.StatusUsed,
|
Status: model.StatusUsed,
|
||||||
@@ -331,8 +353,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
adjustmentRecord.UsedAt = &now
|
adjustmentRecord.UsedAt = &now
|
||||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||||
// Log error but don't fail the update
|
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
|
||||||
// The user update has already succeeded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,12 +372,14 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
|||||||
return s.userRepo.Delete(ctx, id)
|
return s.userRepo.Delete(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) {
|
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
|
||||||
user, err := s.userRepo.GetByID(ctx, userID)
|
user, err := s.userRepo.GetByID(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldBalance := user.Balance
|
||||||
|
|
||||||
switch operation {
|
switch operation {
|
||||||
case "set":
|
case "set":
|
||||||
user.Balance = balance
|
user.Balance = balance
|
||||||
@@ -366,24 +389,53 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
user.Balance -= balance
|
user.Balance -= balance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.Balance < 0 {
|
||||||
|
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 失效余额缓存
|
|
||||||
if s.billingCacheService != nil {
|
if s.billingCacheService != nil {
|
||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||||
|
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
balanceDiff := user.Balance - oldBalance
|
||||||
|
if balanceDiff != 0 {
|
||||||
|
code, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
adjustmentRecord := &model.RedeemCode{
|
||||||
|
Code: code,
|
||||||
|
Type: model.AdjustmentTypeAdminBalance,
|
||||||
|
Value: balanceDiff,
|
||||||
|
Status: model.StatusUsed,
|
||||||
|
UsedBy: &user.ID,
|
||||||
|
Notes: notes,
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
adjustmentRecord.UsedAt = &now
|
||||||
|
|
||||||
|
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||||
|
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@@ -391,9 +443,9 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
|
|||||||
return keys, result.Total, nil
|
return keys, result.Total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) {
|
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||||
// Return mock data for now
|
// Return mock data for now
|
||||||
return map[string]interface{}{
|
return map[string]any{
|
||||||
"period": period,
|
"period": period,
|
||||||
"total_requests": 0,
|
"total_requests": 0,
|
||||||
"total_cost": 0.0,
|
"total_cost": 0.0,
|
||||||
@@ -404,7 +456,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
|||||||
|
|
||||||
// Group management implementations
|
// Group management implementations
|
||||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
|
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@@ -566,7 +618,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
|||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
for _, userID := range affectedUserIDs {
|
for _, userID := range affectedUserIDs {
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
|
||||||
|
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -575,7 +629,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@@ -585,7 +639,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
|||||||
|
|
||||||
// Account management implementations
|
// Account management implementations
|
||||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
|
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@@ -633,10 +687,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
if input.Type != "" {
|
if input.Type != "" {
|
||||||
account.Type = input.Type
|
account.Type = input.Type
|
||||||
}
|
}
|
||||||
if input.Credentials != nil && len(input.Credentials) > 0 {
|
if len(input.Credentials) > 0 {
|
||||||
account.Credentials = model.JSONB(input.Credentials)
|
account.Credentials = model.JSONB(input.Credentials)
|
||||||
}
|
}
|
||||||
if input.Extra != nil && len(input.Extra) > 0 {
|
if len(input.Extra) > 0 {
|
||||||
account.Extra = model.JSONB(input.Extra)
|
account.Extra = model.JSONB(input.Extra)
|
||||||
}
|
}
|
||||||
if input.ProxyID != nil {
|
if input.ProxyID != nil {
|
||||||
@@ -668,6 +722,65 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BulkUpdateAccounts updates multiple accounts in one request.
|
||||||
|
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||||
|
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||||
|
result := &BulkUpdateAccountsResult{
|
||||||
|
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(input.AccountIDs) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare bulk updates for columns and JSONB fields.
|
||||||
|
repoUpdates := ports.AccountBulkUpdate{
|
||||||
|
Credentials: input.Credentials,
|
||||||
|
Extra: input.Extra,
|
||||||
|
}
|
||||||
|
if input.Name != "" {
|
||||||
|
repoUpdates.Name = &input.Name
|
||||||
|
}
|
||||||
|
if input.ProxyID != nil {
|
||||||
|
repoUpdates.ProxyID = input.ProxyID
|
||||||
|
}
|
||||||
|
if input.Concurrency != nil {
|
||||||
|
repoUpdates.Concurrency = input.Concurrency
|
||||||
|
}
|
||||||
|
if input.Priority != nil {
|
||||||
|
repoUpdates.Priority = input.Priority
|
||||||
|
}
|
||||||
|
if input.Status != "" {
|
||||||
|
repoUpdates.Status = &input.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run bulk update for column/jsonb fields first.
|
||||||
|
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle group bindings per account (requires individual operations).
|
||||||
|
for _, accountID := range input.AccountIDs {
|
||||||
|
entry := BulkUpdateAccountResult{AccountID: accountID}
|
||||||
|
|
||||||
|
if input.GroupIDs != nil {
|
||||||
|
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
|
||||||
|
entry.Success = false
|
||||||
|
entry.Error = err.Error()
|
||||||
|
result.Failed++
|
||||||
|
result.Results = append(result.Results, entry)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.Success = true
|
||||||
|
result.Success++
|
||||||
|
result.Results = append(result.Results, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||||
return s.accountRepo.Delete(ctx, id)
|
return s.accountRepo.Delete(ctx, id)
|
||||||
}
|
}
|
||||||
@@ -703,7 +816,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
|
|||||||
|
|
||||||
// Proxy management implementations
|
// Proxy management implementations
|
||||||
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
|
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
|
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@@ -788,7 +901,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
|
|||||||
|
|
||||||
// Redeem code management implementations
|
// Redeem code management implementations
|
||||||
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
|
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
|
||||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
|
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@@ -818,8 +931,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
|
|||||||
|
|
||||||
codes := make([]model.RedeemCode, 0, input.Count)
|
codes := make([]model.RedeemCode, 0, input.Count)
|
||||||
for i := 0; i < input.Count; i++ {
|
for i := 0; i < input.Count; i++ {
|
||||||
|
codeValue, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
code := model.RedeemCode{
|
code := model.RedeemCode{
|
||||||
Code: model.GenerateRedeemCode(),
|
Code: codeValue,
|
||||||
Type: input.Type,
|
Type: input.Type,
|
||||||
Value: input.Value,
|
Value: input.Value,
|
||||||
Status: model.StatusUnused,
|
Status: model.StatusUnused,
|
||||||
@@ -872,79 +989,12 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return testProxyConnection(ctx, proxy)
|
|
||||||
}
|
|
||||||
|
|
||||||
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
|
|
||||||
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
|
|
||||||
proxyURL := proxy.URL()
|
proxyURL := proxy.URL()
|
||||||
|
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
||||||
// Create HTTP client with proxy
|
|
||||||
transport, err := createProxyTransport(proxyURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &ProxyTestResult{
|
return &ProxyTestResult{
|
||||||
Success: false,
|
Success: false,
|
||||||
Message: fmt.Sprintf("Failed to create proxy transport: %v", err),
|
Message: err.Error(),
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: transport,
|
|
||||||
Timeout: 15 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Measure latency
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
|
||||||
if err != nil {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: false,
|
|
||||||
Message: fmt.Sprintf("Failed to create request: %v", err),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: false,
|
|
||||||
Message: fmt.Sprintf("Proxy connection failed: %v", err),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
latencyMs := time.Since(startTime).Milliseconds()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: false,
|
|
||||||
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
|
|
||||||
LatencyMs: latencyMs,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse ipinfo.io response
|
|
||||||
var ipInfo struct {
|
|
||||||
IP string `json:"ip"`
|
|
||||||
City string `json:"city"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Country string `json:"country"`
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: true,
|
|
||||||
Message: "Proxy is accessible but failed to read response",
|
|
||||||
LatencyMs: latencyMs,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: true,
|
|
||||||
Message: "Proxy is accessible but failed to parse response",
|
|
||||||
LatencyMs: latencyMs,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -952,38 +1002,9 @@ func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestRes
|
|||||||
Success: true,
|
Success: true,
|
||||||
Message: "Proxy is accessible",
|
Message: "Proxy is accessible",
|
||||||
LatencyMs: latencyMs,
|
LatencyMs: latencyMs,
|
||||||
IPAddress: ipInfo.IP,
|
IPAddress: exitInfo.IP,
|
||||||
City: ipInfo.City,
|
City: exitInfo.City,
|
||||||
Region: ipInfo.Region,
|
Region: exitInfo.Region,
|
||||||
Country: ipInfo.Country,
|
Country: exitInfo.Country,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createProxyTransport creates an HTTP transport with the given proxy URL
|
|
||||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
|
||||||
parsedURL, err := url.Parse(proxyURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
}
|
|
||||||
|
|
||||||
switch parsedURL.Scheme {
|
|
||||||
case "http", "https":
|
|
||||||
transport.Proxy = http.ProxyURL(parsedURL)
|
|
||||||
case "socks5":
|
|
||||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
|
||||||
}
|
|
||||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
return transport, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"sub2api/internal/model"
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
@@ -17,18 +18,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrApiKeyNotFound = errors.New("api key not found")
|
ErrApiKeyNotFound = errors.New("api key not found")
|
||||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||||
ErrApiKeyExists = errors.New("api key already exists")
|
ErrApiKeyExists = errors.New("api key already exists")
|
||||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
|
apiKeyMaxErrorsPerHour = 20
|
||||||
apiKeyMaxErrorsPerHour = 20
|
|
||||||
apiKeyRateLimitDuration = time.Hour
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateApiKeyRequest 创建API Key请求
|
// CreateApiKeyRequest 创建API Key请求
|
||||||
@@ -47,21 +46,21 @@ type UpdateApiKeyRequest struct {
|
|||||||
|
|
||||||
// ApiKeyService API Key服务
|
// ApiKeyService API Key服务
|
||||||
type ApiKeyService struct {
|
type ApiKeyService struct {
|
||||||
apiKeyRepo *repository.ApiKeyRepository
|
apiKeyRepo ports.ApiKeyRepository
|
||||||
userRepo *repository.UserRepository
|
userRepo ports.UserRepository
|
||||||
groupRepo *repository.GroupRepository
|
groupRepo ports.GroupRepository
|
||||||
userSubRepo *repository.UserSubscriptionRepository
|
userSubRepo ports.UserSubscriptionRepository
|
||||||
rdb *redis.Client
|
cache ports.ApiKeyCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewApiKeyService 创建API Key服务实例
|
// NewApiKeyService 创建API Key服务实例
|
||||||
func NewApiKeyService(
|
func NewApiKeyService(
|
||||||
apiKeyRepo *repository.ApiKeyRepository,
|
apiKeyRepo ports.ApiKeyRepository,
|
||||||
userRepo *repository.UserRepository,
|
userRepo ports.UserRepository,
|
||||||
groupRepo *repository.GroupRepository,
|
groupRepo ports.GroupRepository,
|
||||||
userSubRepo *repository.UserSubscriptionRepository,
|
userSubRepo ports.UserSubscriptionRepository,
|
||||||
rdb *redis.Client,
|
cache ports.ApiKeyCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *ApiKeyService {
|
) *ApiKeyService {
|
||||||
return &ApiKeyService{
|
return &ApiKeyService{
|
||||||
@@ -69,7 +68,7 @@ func NewApiKeyService(
|
|||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
rdb: rdb,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
|||||||
|
|
||||||
// 检查字符:只允许字母、数字、下划线、连字符
|
// 检查字符:只允许字母、数字、下划线、连字符
|
||||||
for _, c := range key {
|
for _, c := range key {
|
||||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
|
if (c >= 'a' && c <= 'z') ||
|
||||||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
|
(c >= 'A' && c <= 'Z') ||
|
||||||
return ErrApiKeyInvalidChars
|
(c >= '0' && c <= '9') ||
|
||||||
|
c == '_' || c == '-' {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
return ErrApiKeyInvalidChars
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -112,13 +114,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
|||||||
|
|
||||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||||||
|
|
||||||
count, err := s.rdb.Get(ctx, key).Int()
|
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
// Redis 出错时不阻止用户操作
|
// Redis 出错时不阻止用户操作
|
||||||
return nil
|
return nil
|
||||||
@@ -133,16 +133,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
|||||||
|
|
||||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||||
if s.rdb == nil {
|
if s.cache == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
|
||||||
|
|
||||||
pipe := s.rdb.Pipeline()
|
|
||||||
pipe.Incr(ctx, key)
|
|
||||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
|
||||||
_, _ = pipe.Exec(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||||
@@ -237,7 +232,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List 获取用户的API Key列表
|
// List 获取用户的API Key列表
|
||||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
|
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||||
@@ -272,7 +267,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
// 缓存到Redis(可选,TTL设置为5分钟)
|
||||||
if s.rdb != nil {
|
if s.cache != nil {
|
||||||
// 这里可以序列化并缓存API Key
|
// 这里可以序列化并缓存API Key
|
||||||
_ = cacheKey // 使用变量避免未使用错误
|
_ = cacheKey // 使用变量避免未使用错误
|
||||||
}
|
}
|
||||||
@@ -325,9 +320,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
|||||||
if req.Status != nil {
|
if req.Status != nil {
|
||||||
apiKey.Status = *req.Status
|
apiKey.Status = *req.Status
|
||||||
// 如果状态改变,清除Redis缓存
|
// 如果状态改变,清除Redis缓存
|
||||||
if s.rdb != nil {
|
if s.cache != nil {
|
||||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||||
_ = s.rdb.Del(ctx, cacheKey)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,9 +348,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 清除Redis缓存
|
// 清除Redis缓存
|
||||||
if s.rdb != nil {
|
if s.cache != nil {
|
||||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||||
_ = s.rdb.Del(ctx, cacheKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||||
@@ -399,13 +392,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
|
|||||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||||
// 使用Redis计数器
|
// 使用Redis计数器
|
||||||
if s.rdb != nil {
|
if s.cache != nil {
|
||||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||||
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
|
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
|
||||||
return fmt.Errorf("increment usage: %w", err)
|
return fmt.Errorf("increment usage: %w", err)
|
||||||
}
|
}
|
||||||
// 设置24小时过期
|
// 设置24小时过期
|
||||||
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
|
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -462,3 +455,11 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
|
|||||||
// 标准类型分组:使用原有逻辑
|
// 标准类型分组:使用原有逻辑
|
||||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||||
|
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("search api keys: %w", err)
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user