Compare commits

...

16 Commits

Author SHA1 Message Date
hi_yueban
f57f12c6cc fix: 修复 OpenAI 账号 5h/7d 使用限制显示错误的问题 (#30)
* fix: 修复 OpenAI 账号 5h/7d 使用限制显示错误的问题

问题描述:
- 账号管理页面中,OpenAI OAuth 账号的 5h 列显示 7 天的剩余时间
- 7d 列却显示几小时的剩余时间
- 根本原因: OpenAI 响应头中 primary/secondary 的实际含义与代码假设相反

修复方案:
1. 后端归一化 (openai_gateway_service.go):
   - 根据 window_minutes 动态判断哪个是 5h/7d 限制
   - 新增规范字段 codex_5h_* 和 codex_7d_*
   - 保留旧字段以兼容性

2. 前端适配 (AccountUsageCell.vue):
   - 优先使用新的规范字段
   - Fallback 到旧字段时基于 window_minutes 动态判断
   - 更新 computed 属性命名

3. 类型定义更新 (types/index.ts):
   - 添加新的规范字段定义
   - 更新注释说明实际语义由 window_minutes 决定

🤖 Generated with Claude Code and Codex collaboration

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-Authored-By: OpenAI Codex <noreply@openai.com>

* fix: 改进窗口判断逻辑,修复两个窗口都小于阈值时的bug

问题:
当两个窗口都小于360分钟时(如 primary=180分钟,secondary=300分钟),
之前的逻辑会导致:
- primary5h = true, secondary5h = true
- 5h 字段会使用 primary(错误)
- 7d 字段没有数据(bug)

修复方案:
改用比较策略:
1. 当两个窗口都存在时:较小的分配给5h,较大的分配给7d
2. 当只有一个窗口时:根据大小(<=360分钟)判断是5h还是7d
3. 确保数据不会丢失,逻辑更健壮

示例:
- Primary: 180分钟, Secondary: 300分钟
  → 5h 使用 Primary(180分钟), 7d 使用 Secondary(300分钟) ✓

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix: 修正窗口大小判断逻辑 - 不能用剩余时间判断窗口类型

**严重bug修复:**
之前的 fallback 逻辑错误地使用 reset_after_seconds 来判断窗口大小。

问题示例:
- 周限制(7d)剩余 2h → reset_after_seconds = 7200秒
- 5h限制 剩余 4h → reset_after_seconds = 14400秒
- 错误逻辑:7200/60 < 14400/60,把周限制当成5h 

根本问题:
- window_minutes = 窗口的总大小(300 or 10080)
- reset_after_seconds = 距离重置的剩余时间(变化的)
- 不能用剩余时间来判断窗口类型!

修复方案:
1. **只使用 window_minutes** 来判断窗口大小
2. 移除错误的 reset_after_seconds fallback
3. 如果 window_minutes 都不存在,使用传统假设
4. 添加详细注释说明这个陷阱

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix: 修复 lint 问题 - 改进 fallback 逻辑的变量赋值

问题:
第882-883行的简单布尔赋值可能触发 ineffassign 或 staticcheck 警告:
  use5hFromSecondary = snapshot.SecondaryUsedPercent != nil
  use7dFromPrimary = snapshot.PrimaryUsedPercent != nil

修复:
改用明确的 if 语句检查任意字段是否存在,更符合代码意图:
- 如果 secondary 的任意字段存在,将其视为 5h
- 如果 primary 的任意字段存在,将其视为 7d

这样逻辑更清晰,也避免了 lint 警告。

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: OpenAI Codex <noreply@openai.com>
2025-12-25 17:00:02 +08:00
shaw
5fca2d10b9 fix: 修复image地址 2025-12-25 16:57:29 +08:00
shaw
8fbe1ad70d chore: 调整403重试次数跟间隔 2025-12-25 16:23:31 +08:00
Forest
25a304c231 test: 增加 repository 测试 2025-12-25 16:01:17 +08:00
刀刀
9d30ceae8d CC 400 返回具体错误信息 && 非 CC 请求时增加 system prompt (#26)
* feat: http 400 返回具体错误

* 更新 workflows

* 优化打包/docker 构建流程

* 400 是返回 原始错误 - json 格式

* feat: 非 cc请求时补充 system

* go mod tidy
2025-12-25 14:47:19 +08:00
IanShaw
60f6ed6bf6 feat: CRS 同步增强 - 自动刷新 OAuth token 和修复测试配置 (#27)
* fix(service): 修复 OpenAI Responses API 测试负载配置

- 所有账号类型统一添加 instructions 字段(不再仅限 OAuth)
- Responses API 要求所有请求必须包含 instructions 参数

* feat(crs-sync): CRS 同步时自动刷新 OAuth token 并保留完整 extra 字段

**核心功能**:
- CRSSyncService 注入 OAuth 服务依赖(Anthropic + OpenAI)
- 账号创建/更新后自动刷新 OAuth token,确保可用性
- 完整保留 CRS extra 字段,避免数据丢失

**Extra 字段增强**:
- 保留 CRS 所有原始 extra 字段
- 新增同步元数据: crs_account_id, crs_kind, crs_synced_at
- Claude 账号: 从 credentials 提取 org_uuid/account_uuid 到 extra
- OpenAI 账号: 映射 crs_email -> email

**Token 刷新逻辑**:
- 新增 refreshOAuthToken() 方法处理 Anthropic/OpenAI 平台
- 保留原有 credentials 字段,仅更新 token 相关字段
- 刷新失败静默处理,不中断同步流程

**依赖注入**:
- wire_gen.go: CRSSyncService 新增 oAuthService/openaiOAuthService

* style(crs-sync): 使用 switch 替代 if-else 修复 golangci-lint 警告

- 将 refreshOAuthToken 中的 if-else 改为 switch 语句
- 符合 staticcheck 规范
- 添加 default 分支处理未知平台
2025-12-25 14:45:17 +08:00
shaw
4a2f7d4a99 chore: CRS迁移功能增加版本提示 2025-12-25 10:57:04 +08:00
shaw
c19a393be9 Merge PR #24: feat: 添加账户同步与批量编辑功能
- 添加从 CRS 同步账户功能 (Claude OAuth/API Key, OpenAI OAuth/Responses)
- 添加批量编辑账户功能,支持 JSONB 字段智能合并
- 新增 CRSSyncService、BulkUpdate 仓储方法
- 前端新增 SyncFromCrsModal 和 BulkEditAccountModal 组件
2025-12-25 10:44:40 +08:00
ianshaw
938ffb002e style(frontend): format code with prettier
格式化前端业务代码,符合代码规范
- 统一代码风格
- 修复 ESLint 警告
2025-12-24 18:07:58 -08:00
ianshaw
372a01290b fix(backend): handle defer Close() errors in crs_sync_service
修复 golangci-lint 错误检查问题
- 使用匿名函数包装 defer Close() 并忽略错误
- 符合 Go 最佳实践
2025-12-24 17:58:47 -08:00
ianshaw
8b163ca49b chore: trigger CI after enabling Actions 2025-12-24 17:56:55 -08:00
ianshaw
d23810dc53 chore: trigger CI workflow 2025-12-24 17:54:43 -08:00
ianshaw
62ed5422dd feat(account): 优化批量更新实现,使用统一 SQL 合并 JSONB 字段
- 新增 BulkUpdate 仓储方法,使用单条 SQL 更新所有账户
- credentials/extra 使用 COALESCE(...) || ? 合并,只更新传入的 key
- name/proxy_id/concurrency/priority/status 只在提供时更新
- 分组绑定仍逐账号处理(需要独立操作)
- 前端优化:Base URL 留空则不修改,按勾选字段更新
- 完善 i18n 文案:说明留空不修改、批量更新行为
2025-12-24 17:16:19 -08:00
ianshaw
2e76302af7 feat(account): 添加批量编辑账户凭据功能并优化 CRS 同步
- 新增批量更新账户凭据接口(account_uuid/org_uuid/intercept_warmup_requests)
- 新增前端批量编辑模态框组件
- 优化 CRS 同步逻辑,改进 extra 字段处理
- 优化 CRS 同步 UI,添加更详细的结果展示
- 完善国际化文案(中英文)
2025-12-24 16:56:48 -08:00
ianshaw
6553828008 feat(account): 添加从 CRS 同步账户功能
- 添加账户同步 API 接口 (account_handler.go)
- 实现 CRS 同步服务 (crs_sync_service.go)
- 添加前端同步对话框组件 (SyncFromCrsModal.vue)
- 更新账户管理界面支持同步操作
- 添加账户仓库批量创建方法
- 添加中英文国际化翻译
- 更新依赖注入配置
2025-12-24 08:48:58 -08:00
ianshaw
adcb7bf00e chore: 更新 .gitignore 忽略配置文件并还原 Makefile
- 添加 backend/config.yaml 到 .gitignore(包含敏感信息)
- 添加 deploy/config.yaml 到 .gitignore(包含敏感信息)
- 添加 backend/.installed 到 .gitignore
- 还原 Makefile 到原始版本
2025-12-24 08:48:49 -08:00
62 changed files with 10647 additions and 307 deletions

View File

@@ -17,9 +17,12 @@ jobs:
go-version-file: backend/go.mod
check-latest: true
cache: true
- name: Run tests
- name: Unit tests
working-directory: backend
run: go test ./...
run: make test-unit
- name: Integration tests
working-directory: backend
run: make test-integration
golangci-lint:
runs-on: ubuntu-latest

View File

@@ -85,6 +85,19 @@ jobs:
go-version: '1.24'
cache-dependency-path: backend/go.sum
# Docker setup for GoReleaser
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Fetch tags with annotations
run: |
# 确保获取完整的 annotated tag 信息
@@ -117,87 +130,16 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
# ===========================================================================
# Docker Build and Push
# ===========================================================================
docker:
needs: [update-version, build-frontend]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Download VERSION artifact
uses: actions/download-artifact@v4
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
uses: actions/download-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
# Extract version from tag
- name: Extract version
id: version
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Version: $VERSION"
# Set up Docker Buildx for multi-platform builds
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Login to DockerHub
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Extract metadata for Docker
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: |
weishaw/sub2api
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=latest,enable={{is_default_branch}}
# Build and push Docker image
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
COMMIT=${{ github.sha }}
DATE=${{ github.event.head_commit.timestamp }}
cache-from: type=gha
cache-to: type=gha,mode=max
# Update DockerHub description (optional)
# Update DockerHub description
- name: Update DockerHub description
uses: peter-evans/dockerhub-description@v4
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: weishaw/sub2api
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
short-description: "Sub2API - AI API Gateway Platform"
readme-filepath: ./deploy/DOCKER.md

7
.gitignore vendored
View File

@@ -92,6 +92,13 @@ backend/internal/web/dist/*
# 后端运行时缓存数据
backend/data/
# ===================
# 本地配置文件(包含敏感信息)
# ===================
backend/config.yaml
deploy/config.yaml
backend/.installed
# ===================
# 其他
# ===================

View File

@@ -52,10 +52,58 @@ changelog:
# 禁用自动 changelog完全使用 tag 消息
disable: true
# Docker images
dockers:
- id: amd64
goos: linux
goarch: amd64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
- id: arm64
goos: linux
goarch: arm64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
# Docker manifests for multi-arch support
docker_manifests:
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
release:
github:
owner: Wei-Shaw
name: sub2api
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
name: "{{ .Env.GITHUB_REPO_NAME }}"
draft: false
prerelease: auto
name_template: "Sub2API {{.Version}}"
@@ -73,7 +121,7 @@ release:
**One-line install (Linux):**
```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
```
**Manual download:**
@@ -81,5 +129,5 @@ release:
## 📚 Documentation
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
- [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)

40
Dockerfile.goreleaser Normal file
View File

@@ -0,0 +1,40 @@
# =============================================================================
# Sub2API Dockerfile for GoReleaser
# =============================================================================
# This Dockerfile is used by GoReleaser to build Docker images.
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
# Install runtime dependencies
RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
&& rm -rf /var/cache/apk/*
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
WORKDIR /app
# Copy pre-built binary from GoReleaser
COPY sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
USER sub2api
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
ENTRYPOINT ["/app/sub2api"]

View File

@@ -1,4 +1,4 @@
.PHONY: wire build build-embed
.PHONY: wire build build-embed test-unit test-integration test-cover-integration clean-coverage
wire:
@echo "生成 Wire 代码..."
@@ -13,4 +13,21 @@ build:
build-embed:
@echo "构建后端(嵌入前端)..."
@go build -tags embed -o bin/server ./cmd/server
@echo "构建完成: bin/server (with embedded frontend)"
@echo "构建完成: bin/server (with embedded frontend)"
test-unit:
@go test ./... $(TEST_ARGS)
test-integration:
@go test -tags integration ./internal/repository -count=1 -race -parallel=8
test-cover-integration:
@echo "运行集成测试并生成覆盖率报告..."
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./internal/repository/...
@go tool cover -func=coverage.out | tail -1
@go tool cover -html=coverage.out -o coverage.html
@echo "覆盖率报告已生成: coverage.html"
clean-coverage:
@rm -f coverage.out coverage.html
@echo "覆盖率文件已清理"

View File

@@ -86,7 +86,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService)

View File

@@ -1,38 +0,0 @@
server:
host: "0.0.0.0"
port: 8080
mode: "debug" # debug/release
database:
host: "127.0.0.1"
port: 5432
user: "postgres"
password: "XZeRr7nkjHWhm8fw"
dbname: "sub2api"
sslmode: "disable"
redis:
host: "127.0.0.1"
port: 6379
password: ""
db: 0
jwt:
secret: "your-secret-key-change-in-production"
expire_hour: 24
default:
admin_email: "admin@sub2api.com"
admin_password: "admin123"
user_concurrency: 5
user_balance: 0
api_key_prefix: "sk-"
rate_multiplier: 1.0
# Timezone configuration (similar to PHP's date_default_timezone_set)
# This affects ALL time operations:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
timezone: "Asia/Shanghai"

View File

@@ -8,10 +8,16 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/imroc/req/v3 v3.56.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.3.0
github.com/redis/go-redis/v9 v9.7.3
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0
@@ -21,51 +27,99 @@ require (
)
require (
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/google/wire v0.7.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
github.com/moby/sys/user v0.4.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
@@ -75,6 +129,7 @@ require (
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)

View File

@@ -1,3 +1,11 @@
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -7,17 +15,43 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
@@ -28,6 +62,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -40,9 +81,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
@@ -54,6 +94,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
@@ -64,8 +106,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
@@ -85,36 +129,70 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
@@ -128,6 +206,8 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -135,16 +215,55 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ=
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0 h1:OG4qwcxp2O0re7V7M9lY9w0v6wWgWf7j7rtkpAnGMd0=
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0/go.mod h1:Bc+EDhKMo5zI5V5zdBkHiMVzeAXbtI4n5isS/nzf6zw=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
@@ -164,8 +283,14 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
@@ -177,9 +302,15 @@ golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 h1:i8QOKZfYg6AbGVZzUAY3LrNWCKF8O6zFisU9Wl9RER4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ=
google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI=
google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
@@ -192,6 +323,6 @@ gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -34,10 +34,20 @@ type AccountHandler struct {
accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
}
// NewAccountHandler creates a new admin account handler
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, openaiOAuthService *service.OpenAIOAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService, concurrencyService *service.ConcurrencyService) *AccountHandler {
func NewAccountHandler(
adminService service.AdminService,
oauthService *service.OAuthService,
openaiOAuthService *service.OpenAIOAuthService,
rateLimitService *service.RateLimitService,
accountUsageService *service.AccountUsageService,
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
oauthService: oauthService,
@@ -46,6 +56,7 @@ func NewAccountHandler(adminService service.AdminService, oauthService *service.
accountUsageService: accountUsageService,
accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
}
}
@@ -76,6 +87,19 @@ type UpdateAccountRequest struct {
GroupIDs *[]int64 `json:"group_ids"`
}
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*model.Account
@@ -224,6 +248,13 @@ type TestAccountRequest struct {
ModelID string `json:"model_id"`
}
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
}
// Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) {
@@ -244,6 +275,35 @@ func (h *AccountHandler) Test(c *gin.Context) {
}
}
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
// POST /api/v1/admin/accounts/sync/crs
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
var req SyncFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Default to syncing proxies (can be disabled by explicitly setting false)
syncProxies := true
if req.SyncProxies != nil {
syncProxies = *req.SyncProxies
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
})
if err != nil {
response.BadRequest(c, "Sync failed: "+err.Error())
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
@@ -387,6 +447,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
})
}
// BatchUpdateCredentialsRequest represents batch credentials update request
type BatchUpdateCredentialsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
Value any `json:"value"`
}
// BatchUpdateCredentials handles batch updating credentials fields
// POST /api/v1/admin/accounts/batch-update-credentials
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
var req BatchUpdateCredentialsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Validate value type based on field
if req.Field == "intercept_warmup_requests" {
// Must be boolean
if _, ok := req.Value.(bool); !ok {
response.BadRequest(c, "intercept_warmup_requests must be boolean")
return
}
} else {
// account_uuid and org_uuid can be string or null
if req.Value != nil {
if _, ok := req.Value.(string); !ok {
response.BadRequest(c, req.Field+" must be string or null")
return
}
}
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
})
}
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
// POST /api/v1/admin/accounts/bulk-update
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
var req BulkUpdateAccountsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.Status != "" ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
if !hasUpdates {
response.BadRequest(c, "No updates provided")
return
}
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
})
if err != nil {
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
return
}
response.Success(c, result)
}
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL

View File

@@ -2,11 +2,14 @@ package repository
import (
"context"
"errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type AccountRepository struct {
@@ -39,6 +42,22 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou
return &account, nil
}
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
if crsAccountID == "" {
return nil, nil
}
var account model.Account
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &account, nil
}
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error
}
@@ -335,3 +354,47 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("extra", account.Extra).Error
}
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
updateMap := map[string]any{}
if updates.Name != nil {
updateMap["name"] = *updates.Name
}
if updates.ProxyID != nil {
updateMap["proxy_id"] = updates.ProxyID
}
if updates.Concurrency != nil {
updateMap["concurrency"] = *updates.Concurrency
}
if updates.Priority != nil {
updateMap["priority"] = *updates.Priority
}
if updates.Status != nil {
updateMap["status"] = *updates.Status
}
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
}
if len(updateMap) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).
Model(&model.Account{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
return result.RowsAffected, result.Error
}

View File

@@ -0,0 +1,580 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type AccountRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *AccountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db)
}
func TestAccountRepoSuite(t *testing.T) {
suite.Run(t, new(AccountRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() {
account := &model.Account{
Name: "test-create",
Platform: model.PlatformAnthropic,
Type: model.AccountTypeOAuth,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, account)
s.Require().NoError(err, "Create")
s.Require().NotZero(account.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *AccountRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *AccountRepoSuite) TestUpdate() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"})
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, account.ID)
s.Require().Error(err, "expected error after delete")
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(accounts, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct {
name string
setup func(db *gorm.DB)
platform string
accType string
status string
search string
wantCount int
validate func(accounts []model.Account)
}{
{
name: "filter_by_platform",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic})
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI})
},
platform: model.PlatformOpenAI,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform)
},
},
{
name: "filter_by_type",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey})
},
accType: model.AccountTypeApiKey,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type)
},
},
{
name: "filter_by_status",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive})
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled})
},
status: model.StatusDisabled,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.StatusDisabled, accounts[0].Status)
},
},
{
name: "filter_by_search",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Contains(accounts[0].Name, "alpha")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
db := testTx(s.T())
repo := NewAccountRepository(db)
ctx := context.Background()
tt.setup(db)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {
tt.validate(accounts)
}
})
}
}
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
s.Require().NoError(err, "ListByGroup")
s.Require().Len(accounts, 2)
// Should be ordered by priority
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(accounts, 1)
s.Require().Equal("active1", accounts[0].Name)
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
}
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &model.Account{
Name: "acc1",
ProxyID: &proxy.ID,
})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.Proxy, "expected Proxy preload")
s.Require().Equal(proxy.ID, got.Proxy.ID)
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
s.Require().Equal(group.ID, got.GroupIDs[0])
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1)
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
}
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups")
s.Require().Len(groups, 1, "expected 1 group")
s.Require().Equal(g1.ID, groups[0].ID)
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after remove")
s.Require().Empty(groups, "expected 0 groups after remove")
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after bind")
s.Require().Len(groups, 2, "expected 2 groups after bind")
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Empty(groups, "expected 0 groups after binding empty list")
}
// --- Schedulable ---
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
s.Require().NoError(err, "ListSchedulable")
ids := idsOfAccounts(sched)
s.Require().Contains(ids, okAcc.ID)
s.Require().NotContains(ids, overloaded.ID)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID")
s.Require().Len(sched, 1, "expected only ok account schedulable")
s.Require().Equal(okAcc.ID, sched[0].ID)
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID)
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.OverloadUntil)
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.RateLimitedAt)
s.Require().NotNil(got.RateLimitResetAt)
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Nil(got.RateLimitedAt)
s.Require().Nil(got.RateLimitResetAt)
s.Require().Nil(got.OverloadUntil)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.LastUsedAt)
}
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage)
}
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.SessionWindowStart)
s.Require().NotNil(got.SessionWindowEnd)
s.Require().Equal("active", got.SessionWindowStatus)
}
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &model.Account{
Name: "acc-extra",
Extra: model.JSONB{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("1", got.Extra["a"])
s.Require().Equal("2", got.Extra["b"])
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal("val", got.Extra["key"])
}
// --- GetByCRSAccountID ---
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &model.Account{
Name: "acc-crs",
Extra: model.JSONB{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
s.Require().NoError(err)
s.Require().NotNil(got)
s.Require().Equal("acc-crs", got.Name)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
s.Require().NoError(err)
s.Require().Nil(got)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
s.Require().NoError(err)
s.Require().Nil(got)
}
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, ports.AccountBulkUpdate{
Priority: &newPriority,
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
s.Require().Equal(99, got1.Priority)
s.Require().Equal(99, got2.Priority)
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
Name: "bulk-cred",
Credentials: model.JSONB{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
Credentials: model.JSONB{"new_key": "new_value"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("value", got.Credentials["existing"])
s.Require().Equal("new_value", got.Credentials["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
Name: "bulk-extra",
Extra: model.JSONB{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
Extra: model.JSONB{"new_key": "new_val"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("val", got.Extra["existing"])
s.Require().Equal("new_val", got.Extra["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, ports.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func idsOfAccounts(accounts []model.Account) []int64 {
out := make([]int64, 0, len(accounts))
for i := range accounts {
out = append(out, accounts[i].ID)
}
return out
}

View File

@@ -0,0 +1,125 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ApiKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
_, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key")
},
},
{
name: "increment_increases_count_and_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "GetCreateAttemptCount")
require.Equal(s.T(), 2, count, "count mismatch")
ttl, err := rdb.TTL(ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
},
},
{
name: "delete_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
_, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *ApiKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "increment_increases_count",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
n, err := rdb.Get(ctx, dailyKey).Int()
require.NoError(s.T(), err, "Get dailyKey")
require.Equal(s.T(), 2, n, "expected daily usage=2")
},
},
{
name: "set_expiry_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test-expiry"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
ttl, err := rdb.TTL(ctx, dailyKey).Result()
require.NoError(s.T(), err, "TTL dailyKey")
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
}

View File

@@ -0,0 +1,355 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *ApiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db)
}
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
key := &model.ApiKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, key)
s.Require().NoError(err, "Create")
s.Require().NotZero(key.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-create-test", got.Key)
}
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: model.StatusActive,
})
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User, "expected User preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: model.StatusActive,
})
key.Name = "Renamed"
key.Status = model.StatusDisabled
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name)
s.Require().Equal(model.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
})
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err)
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
}
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
})
err := s.repo.Delete(s.ctx, key.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, key.ID)
s.Require().Error(err, "expected error after delete")
}
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"})
for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
})
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
s.Require().Equal(3, page.Pages)
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
s.Require().Equal(int64(2), count)
}
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
// User preloaded
s.Require().NotNil(keys[0].User)
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), count)
}
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists)
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(2), affected)
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
s.Require().Nil(got1.GroupID)
s.Require().Nil(got2.GroupID)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: model.StatusActive,
})
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User)
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group)
s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed"
key.Status = model.StatusDisabled
key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
got2, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(model.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-test-2",
Name: "Group Key",
GroupID: &group.ID,
})
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got3, err := s.repo.GetByID(s.ctx, k2.ID)
s.Require().NoError(err, "GetByID")
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID after clear")
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}

View File

@@ -0,0 +1,283 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type BillingCacheSuite struct {
IntegrationRedisSuite
}
func (s *BillingCacheSuite) TestUserBalance() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
_, err := cache.GetUserBalance(ctx, 1)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
},
},
{
name: "deduct_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
_, err := rdb.Get(ctx, balanceKey).Result()
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(2)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 10.5, got, "balance mismatch")
ttl, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "deduct_reduces_balance",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(3)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance after deduct")
require.Equal(s.T(), 8.25, got, "deduct mismatch")
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(100)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
exists, err := rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
exists, err = rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
_, err = cache.GetUserBalance(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "deduct_refreshes_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(103)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL before deduct")
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
balance, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL after deduct")
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *BillingCacheSuite) TestSubscriptionCache() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(10)
groupID := int64(20)
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
},
},
{
name: "update_usage_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(11)
groupID := int64(21)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(12)
groupID := int64(22)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &ports.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 7,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache")
require.Equal(s.T(), "active", gotSub.Status)
require.Equal(s.T(), int64(7), gotSub.Version)
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
ttl, err := rdb.TTL(ctx, subKey).Result()
require.NoError(s.T(), err, "TTL subKey")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "update_usage_increments_all_fields",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(13)
groupID := int64(23)
data := &ports.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache after update")
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(101)
groupID := int64(10)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &ports.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
exists, err = rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "missing_status_returns_parsing_error",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(102)
groupID := int64(11)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]any{
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
"daily_usage": 1.0,
"weekly_usage": 2.0,
"monthly_usage": 3.0,
"version": 1,
}
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.Error(s.T(), err, "expected error for missing status field")
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
require.Equal(s.T(), "invalid cache: missing status", err.Error())
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}

View File

@@ -16,20 +16,28 @@ import (
"github.com/imroc/req/v3"
)
type claudeOAuthService struct{}
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{}
return &claudeOAuthService{
baseURL: "https://claude.ai",
tokenURL: oauth.TokenURL,
clientFactory: createReqClient,
}
}
type claudeOAuthService struct {
baseURL string
tokenURL string
clientFactory func(proxyURL string) *req.Client
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
@@ -61,9 +69,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
reqBody := map[string]any{
"response_type": "code",
@@ -133,12 +141,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code
@@ -161,7 +169,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
@@ -171,7 +179,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
Post(s.tokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
@@ -189,7 +197,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -202,7 +210,7 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
@@ -226,3 +234,13 @@ func createReqClient(proxyURL string) *req.Client {
return client
}
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}

View File

@@ -0,0 +1,343 @@
package repository
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeOAuthServiceSuite struct {
suite.Suite
srv *httptest.Server
client *claudeOAuthService
}
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct {
path string
method string
cookies []*http.Cookie
body []byte
formValues url.Values
bodyJSON map[string]any
contentType string
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
errContain string
wantUUID string
validate func(captured requestCapture)
}{
{
name: "success",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
},
wantUUID: "org-1",
validate: func(captured requestCapture) {
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
require.Equal(s.T(), "sess", captured.cookies[0].Value)
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized"))
},
wantErr: true,
errContain: "401",
},
{
name: "invalid_json_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte("not-json"))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.cookies = r.Cookies()
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
if tt.wantErr {
require.Error(s.T(), err)
if tt.errContain != "" {
require.ErrorContains(s.T(), err, tt.errContain)
}
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantUUID, got)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantCode string
validate func(captured requestCapture)
}{
{
name: "parses_redirect_uri",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
})
},
wantCode: "AUTH#STATE",
validate: func(captured requestCapture) {
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sess", captured.cookies[0].Value)
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "st", captured.bodyJSON["state"])
},
},
{
name: "missing_code_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
})
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.method = r.Method
captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantCode, code)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
tests := []struct {
name string
handler http.HandlerFunc
code string
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_state_when_embedded",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 3600,
RefreshToken: "rt",
Scope: "s",
})
},
code: "AUTH#STATE2",
wantResp: &oauth.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad request"))
},
code: "AUTH",
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_form",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
},
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized"))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.body, _ = io.ReadAll(r.Body)
captured.formValues, _ = url.ParseQuery(string(captured.body))
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func TestClaudeOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeOAuthServiceSuite))
}

View File

@@ -12,10 +12,14 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
type claudeUsageService struct{}
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{}
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
@@ -35,7 +39,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
Timeout: 30 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}

View File

@@ -0,0 +1,105 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeUsageServiceSuite struct {
suite.Suite
srv *httptest.Server
fetcher *claudeUsageService
}
func (s *ClaudeUsageServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// usageRequestCapture holds captured request data for assertions in the main goroutine.
type usageRequestCapture struct {
authorization string
anthropicBeta string
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta")
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
}`)
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
// Assertions on captured request data
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 401")
require.ErrorContains(s.T(), err, "nope")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "decode response failed")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := s.fetcher.FetchUsage(ctx, "at", "")
require.Error(s.T(), err, "expected error for cancelled context")
}
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}

View File

@@ -0,0 +1,231 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache ports.ConcurrencyCache
}
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
accountID := int64(10)
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
require.NoError(s.T(), err, "AcquireAccountSlot 1")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
require.NoError(s.T(), err, "AcquireAccountSlot 2")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
require.NoError(s.T(), err, "AcquireAccountSlot 3")
require.False(s.T(), ok, "expected third acquire to fail")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency")
require.Equal(s.T(), 2, cur, "concurrency mismatch")
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency after release")
require.Equal(s.T(), 1, cur, "expected 1 after release")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
accountID := int64(12)
reqID := "dup-req"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Acquiring with same reqID should be idempotent
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
accountID := int64(13)
reqID := "release-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
// Releasing again should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
// Releasing non-existent should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
accountID := int64(14)
reqID := "max-zero-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
require.NoError(s.T(), err)
require.False(s.T(), ok, "expected acquire to fail with max=0")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
userID := int64(42)
reqID1, reqID2 := "req1", "req2"
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
require.NoError(s.T(), err, "AcquireUserSlot 2")
require.False(s.T(), ok, "expected second acquire to fail at max=1")
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency")
require.Equal(s.T(), 1, cur, "expected concurrency=1")
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
// Releasing a non-existent slot should not error
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency after release")
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
require.NoError(s.T(), err, "IncrementWaitCount")
require.True(s.T(), ok)
// Decrement once (1 -> 0)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
// When no slots exist, GetUserConcurrency should return 0
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}

View File

@@ -0,0 +1,92 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type EmailCacheSuite struct {
IntegrationRedisSuite
cache ports.EmailCache
}
func (s *EmailCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewEmailCache(s.rdb)
}
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
}
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
email := "a@example.com"
emailTTL := 2 * time.Minute
data := &ports.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
got, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode")
require.Equal(s.T(), "123456", got.Code)
require.Equal(s.T(), 1, got.Attempts)
}
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
email := "ttl@example.com"
emailTTL := 2 * time.Minute
data := &ports.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
emailKey := verifyCodeKeyPrefix + email
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
require.NoError(s.T(), err, "TTL emailKey")
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
}
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
email := "delete@example.com"
data := &ports.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
// Verify it exists
_, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode before delete")
// Delete
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
// Verify it's gone
_, err = s.cache.GetVerificationCode(s.ctx, email)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
// Deleting a non-existent key should not error
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
}
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func TestEmailCacheSuite(t *testing.T) {
suite.Run(t, new(EmailCacheSuite))
}

View File

@@ -0,0 +1,172 @@
//go:build integration
package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
t.Helper()
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = model.RoleUser
}
if u.Status == "" {
u.Status = model.StatusActive
}
if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now()
}
if u.UpdatedAt.IsZero() {
u.UpdatedAt = u.CreatedAt
}
require.NoError(t, db.Create(u).Error, "create user")
return u
}
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
t.Helper()
if g.Platform == "" {
g.Platform = model.PlatformAnthropic
}
if g.Status == "" {
g.Status = model.StatusActive
}
if g.SubscriptionType == "" {
g.SubscriptionType = model.SubscriptionTypeStandard
}
if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now()
}
if g.UpdatedAt.IsZero() {
g.UpdatedAt = g.CreatedAt
}
require.NoError(t, db.Create(g).Error, "create group")
return g
}
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
t.Helper()
if p.Protocol == "" {
p.Protocol = "http"
}
if p.Host == "" {
p.Host = "127.0.0.1"
}
if p.Port == 0 {
p.Port = 8080
}
if p.Status == "" {
p.Status = model.StatusActive
}
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now()
}
if p.UpdatedAt.IsZero() {
p.UpdatedAt = p.CreatedAt
}
require.NoError(t, db.Create(p).Error, "create proxy")
return p
}
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account {
t.Helper()
if a.Platform == "" {
a.Platform = model.PlatformAnthropic
}
if a.Type == "" {
a.Type = model.AccountTypeOAuth
}
if a.Status == "" {
a.Status = model.StatusActive
}
if !a.Schedulable {
a.Schedulable = true
}
if a.Credentials == nil {
a.Credentials = model.JSONB{}
}
if a.Extra == nil {
a.Extra = model.JSONB{}
}
if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now()
}
if a.UpdatedAt.IsZero() {
a.UpdatedAt = a.CreatedAt
}
require.NoError(t, db.Create(a).Error, "create account")
return a
}
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey {
t.Helper()
if k.Status == "" {
k.Status = model.StatusActive
}
if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now()
}
if k.UpdatedAt.IsZero() {
k.UpdatedAt = k.CreatedAt
}
require.NoError(t, db.Create(k).Error, "create api key")
return k
}
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode {
t.Helper()
if c.Status == "" {
c.Status = model.StatusUnused
}
if c.Type == "" {
c.Type = model.RedeemTypeBalance
}
if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now()
}
require.NoError(t, db.Create(c).Error, "create redeem code")
return c
}
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription {
t.Helper()
if s.Status == "" {
s.Status = model.SubscriptionStatusActive
}
now := time.Now()
if s.StartsAt.IsZero() {
s.StartsAt = now.Add(-1 * time.Hour)
}
if s.ExpiresAt.IsZero() {
s.ExpiresAt = now.Add(24 * time.Hour)
}
if s.AssignedAt.IsZero() {
s.AssignedAt = now
}
if s.CreatedAt.IsZero() {
s.CreatedAt = now
}
if s.UpdatedAt.IsZero() {
s.UpdatedAt = now
}
require.NoError(t, db.Create(s).Error, "create user subscription")
return s
}
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
t.Helper()
require.NoError(t, db.Create(&model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
}).Error, "create account_group")
}

View File

@@ -0,0 +1,92 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GatewayCacheSuite struct {
IntegrationRedisSuite
cache ports.GatewayCache
}
func (s *GatewayCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGatewayCache(s.rdb)
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
}
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1"
accountID := int64(99)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch")
}
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2"
accountID := int64(100)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := stickySessionPrefix + sessionID
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3"
accountID := int64(101)
initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := stickySessionPrefix + sessionID
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
sessionKey := stickySessionPrefix + sessionID
// Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}

View File

@@ -0,0 +1,328 @@
package repository
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GitHubReleaseServiceSuite struct {
suite.Suite
srv *httptest.Server
client *githubReleaseClient
tempDir string
}
// testTransport redirects requests to the test server
type testTransport struct {
testServerURL string
}
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the URL to point to our test server
testURL := t.testServerURL + req.URL.Path
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body)
if err != nil {
return nil, err
}
newReq.Header = req.Header
return http.DefaultTransport.RoundTrip(newReq)
}
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
func (s *GitHubReleaseServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
for i := 0; i < 10; i++ {
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
require.Error(s.T(), err, "expected error for oversized chunked download")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
for i := 0; i < 10; i++ {
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
require.NoError(s.T(), err, "expected success")
b, err := os.ReadFile(dest)
require.NoError(s.T(), err, "read")
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
require.Len(s.T(), b, 100, "downloaded content length mismatch")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for 404")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to not exist for 404")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("sum"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background())
cancel()
dest := filepath.Join(s.tempDir, "cancelled.bin")
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for cancelled context")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("content"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for invalid destination path")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
releaseJSON := `{
"tag_name": "v1.0.0",
"name": "Release 1.0.0",
"body": "Release notes",
"html_url": "https://github.com/test/repo/releases/v1.0.0",
"assets": [
{
"name": "app-linux-amd64.tar.gz",
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
}
]
}`
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(releaseJSON))
}))
// Use custom transport to redirect requests to test server
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.NoError(s.T(), err)
require.Equal(s.T(), "v1.0.0", release.TagName)
require.Equal(s.T(), "Release 1.0.0", release.Name)
require.Len(s.T(), release.Assets, 1)
require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.Error(s.T(), err)
require.Contains(s.T(), err.Error(), "404")
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not valid json"))
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.Error(s.T(), err)
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
require.Error(s.T(), err)
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
require.Error(s.T(), err)
}
func TestGitHubReleaseServiceSuite(t *testing.T) {
suite.Run(t, new(GitHubReleaseServiceSuite))
}

View File

@@ -0,0 +1,244 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type GroupRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *GroupRepository
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewGroupRepository(s.db)
}
func TestGroupRepoSuite(t *testing.T) {
suite.Run(t, new(GroupRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() {
group := &model.Group{
Name: "test-create",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, group)
s.Require().NoError(err, "Create")
s.Require().NotZero(group.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *GroupRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *GroupRepoSuite) TestUpdate() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"})
group.Name = "updated"
err := s.repo.Update(s.ctx, group)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"})
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(groups, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform)
}
func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(model.StatusDisabled, groups[0].Status)
}
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true})
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive)
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{
Name: "g1",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{
Name: "g2",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
IsExclusive: true,
})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
}
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled})
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(groups, 1)
s.Require().Equal("active1", groups[0].Name)
}
func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled})
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1)
s.Require().Equal("g1", groups[0].Name)
}
// --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"})
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName")
s.Require().True(exists)
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count)
}
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups")
}
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err)
s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- DB ---
func (s *GroupRepoSuite) TestDB() {
db := s.repo.DB()
s.Require().NotNil(db, "DB should return non-nil")
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
}

View File

@@ -0,0 +1,115 @@
package repository
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type HTTPUpstreamSuite struct {
suite.Suite
cfg *config.Config
}
func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{}
}
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
got := svc.createProxyClient("://bad-proxy-url")
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
}
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
s.T().Cleanup(upstream.Close)
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "")
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct", string(b), "unexpected body")
}
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
seen := make(chan string, 1)
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
_, _ = io.WriteString(w, "proxied")
}))
s.T().Cleanup(proxySrv.Close)
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, proxySrv.URL)
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "proxied", string(b), "unexpected body")
select {
case uri := <-seen:
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty")
}))
s.T().Cleanup(upstream.Close)
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "")
require.NoError(s.T(), err, "Do with empty proxy")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct-empty", string(b))
}
func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
}

View File

@@ -0,0 +1,67 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type IdentityCacheSuite struct {
IntegrationRedisSuite
cache *identityCache
}
func (s *IdentityCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewIdentityCache(s.rdb).(*identityCache)
}
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
_, err := s.cache.GetFingerprint(s.ctx, 1)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
}
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
require.NoError(s.T(), err, "GetFingerprint")
require.Equal(s.T(), "c1", gotFP.ClientID)
require.Equal(s.T(), "ua", gotFP.UserAgent)
}
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
require.NoError(s.T(), err, "TTL fpKey")
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
}
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetFingerprint(s.ctx, 999)
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
err := s.cache.SetFingerprint(s.ctx, 100, nil)
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
}
func TestIdentityCacheSuite(t *testing.T) {
suite.Run(t, new(IdentityCacheSuite))
}

View File

@@ -0,0 +1,369 @@
//go:build integration
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"os/exec"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
redisclient "github.com/redis/go-redis/v9"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
gormpostgres "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
redisImageTag = "redis:8.4-alpine"
postgresImageTag = "postgres:18.1-alpine3.23"
)
var (
integrationDB *gorm.DB
integrationRedis *redisclient.Client
redisNamespaceSeq uint64
)
func TestMain(m *testing.M) {
ctx := context.Background()
if err := timezone.Init("UTC"); err != nil {
log.Printf("failed to init timezone: %v", err)
os.Exit(1)
}
if !dockerIsAvailable(ctx) {
// In CI we expect Docker to be available so integration tests should fail loudly.
if os.Getenv("CI") != "" {
log.Printf("docker is not available (CI=true); failing integration tests")
os.Exit(1)
}
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
os.Exit(0)
}
postgresImage := selectDockerImage(ctx, postgresImageTag)
pgContainer, err := tcpostgres.Run(
ctx,
postgresImage,
tcpostgres.WithDatabase("sub2api_test"),
tcpostgres.WithUsername("postgres"),
tcpostgres.WithPassword("postgres"),
tcpostgres.BasicWaitStrategies(),
)
if err != nil {
log.Printf("failed to start postgres container: %v", err)
os.Exit(1)
}
defer func() { _ = pgContainer.Terminate(ctx) }()
redisContainer, err := tcredis.Run(
ctx,
redisImageTag,
)
if err != nil {
log.Printf("failed to start redis container: %v", err)
os.Exit(1)
}
defer func() { _ = redisContainer.Terminate(ctx) }()
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
if err != nil {
log.Printf("failed to get postgres dsn: %v", err)
os.Exit(1)
}
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
if err != nil {
log.Printf("failed to open gorm db: %v", err)
os.Exit(1)
}
if err := model.AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err)
os.Exit(1)
}
redisHost, err := redisContainer.Host(ctx)
if err != nil {
log.Printf("failed to get redis host: %v", err)
os.Exit(1)
}
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
if err != nil {
log.Printf("failed to get redis port: %v", err)
os.Exit(1)
}
integrationRedis = redisclient.NewClient(&redisclient.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
if err := integrationRedis.Ping(ctx).Err(); err != nil {
log.Printf("failed to ping redis: %v", err)
os.Exit(1)
}
code := m.Run()
_ = integrationRedis.Close()
os.Exit(code)
}
func dockerIsAvailable(ctx context.Context) bool {
cmd := exec.CommandContext(ctx, "docker", "info")
cmd.Env = os.Environ()
return cmd.Run() == nil
}
func selectDockerImage(ctx context.Context, preferred string) string {
if dockerImageExists(ctx, preferred) {
return preferred
}
return preferred
}
func dockerImageExists(ctx context.Context, image string) bool {
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
cmd.Env = os.Environ()
cmd.Stdout = nil
cmd.Stderr = nil
return cmd.Run() == nil
}
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
sqlDB, err := db.DB()
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
return db, nil
}
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
}
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
pingCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return db.PingContext(pingCtx)
}
func testTx(t *testing.T) *gorm.DB {
t.Helper()
tx := integrationDB.Begin()
require.NoError(t, tx.Error, "begin tx")
t.Cleanup(func() {
_ = tx.Rollback().Error
})
return tx
}
func testRedis(t *testing.T) *redisclient.Client {
t.Helper()
prefix := fmt.Sprintf(
"it:%s:%d:%d:",
sanitizeRedisNamespace(t.Name()),
time.Now().UnixNano(),
atomic.AddUint64(&redisNamespaceSeq, 1),
)
opts := *integrationRedis.Options()
rdb := redisclient.NewClient(&opts)
rdb.AddHook(prefixHook{prefix: prefix})
t.Cleanup(func() {
ctx := context.Background()
var cursor uint64
for {
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
require.NoError(t, err, "scan redis keys for cleanup")
if len(keys) > 0 {
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
}
cursor = nextCursor
if cursor == 0 {
break
}
}
_ = rdb.Close()
})
return rdb
}
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
t.Helper()
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
}
func sanitizeRedisNamespace(name string) string {
name = strings.ReplaceAll(name, "/", "_")
name = strings.ReplaceAll(name, " ", "_")
return name
}
type prefixHook struct {
prefix string
}
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
return func(ctx context.Context, cmd redisclient.Cmder) error {
h.prefixCmd(cmd)
return next(ctx, cmd)
}
}
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
return func(ctx context.Context, cmds []redisclient.Cmder) error {
for _, cmd := range cmds {
h.prefixCmd(cmd)
}
return next(ctx, cmds)
}
}
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
args := cmd.Args()
if len(args) < 2 {
return
}
prefixOne := func(i int) {
if i < 0 || i >= len(args) {
return
}
switch v := args[i].(type) {
case string:
if v != "" && !strings.HasPrefix(v, h.prefix) {
args[i] = h.prefix + v
}
case []byte:
s := string(v)
if s != "" && !strings.HasPrefix(s, h.prefix) {
args[i] = []byte(h.prefix + s)
}
}
}
switch strings.ToLower(cmd.Name()) {
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
prefixOne(1)
case "del", "unlink":
for i := 1; i < len(args); i++ {
prefixOne(i)
}
case "eval", "evalsha", "eval_ro", "evalsha_ro":
if len(args) < 3 {
return
}
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
if err != nil || numKeys <= 0 {
return
}
for i := 0; i < numKeys && 3+i < len(args); i++ {
prefixOne(3 + i)
}
case "scan":
for i := 2; i+1 < len(args); i++ {
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
prefixOne(i + 1)
break
}
}
}
}
// IntegrationRedisSuite provides a base suite for Redis integration tests.
// Embedding suites should call SetupTest to initialize ctx and rdb.
type IntegrationRedisSuite struct {
suite.Suite
ctx context.Context
rdb *redisclient.Client
}
// SetupTest initializes ctx and rdb for each test method.
func (s *IntegrationRedisSuite) SetupTest() {
s.ctx = context.Background()
s.rdb = testRedis(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}
// AssertTTLWithin asserts that ttl is within [min, max].
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
s.T().Helper()
assertTTLWithin(s.T(), ttl, min, max)
}
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
// Embedding suites should call SetupTest to initialize ctx and db.
type IntegrationDBSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
}
// SetupTest initializes ctx and db for each test method.
func (s *IntegrationDBSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}

View File

@@ -12,11 +12,13 @@ import (
"github.com/imroc/req/v3"
)
type openaiOAuthService struct{}
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
return &openaiOAuthService{}
return &openaiOAuthService{tokenURL: openai.TokenURL}
}
type openaiOAuthService struct {
tokenURL string
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
@@ -39,7 +41,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
@@ -67,7 +69,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)

View File

@@ -0,0 +1,249 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type OpenAIOAuthServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
svc *openaiOAuthService
received chan url.Values
}
func (s *OpenAIOAuthServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
}
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
errCh <- "method mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code"); got != "code" {
errCh <- "code mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code_verifier"); got != "ver" {
errCh <- "code_verifier mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at", resp.AccessToken)
require.Equal(s.T(), "rt", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("refresh_token"); got != "rt" {
errCh <- "refresh_token mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
errCh <- "scope mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at2", resp.AccessToken)
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "bad")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
}
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
want := "http://localhost:9999/cb"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("redirect_uri"); got != want {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
s.received <- r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "unauthorized")
}))
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.Error(s.T(), err, "expected error for non-2xx status")
require.ErrorContains(s.T(), err, "status 401")
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}

View File

@@ -0,0 +1,147 @@
package repository
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type PricingServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
client *pricingRemoteClient
}
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
func (s *PricingServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
}
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ok" {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
return
}
w.WriteHeader(http.StatusInternalServerError)
}))
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
require.NoError(s.T(), err, "FetchPricingJSON")
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/hashfile":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
case "/hashonly":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("def456\n"))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "abc123", hash, "hash mismatch")
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "def456", hash2, "hash mismatch")
}
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
// empty body
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
require.NoError(s.T(), err, "FetchHashText empty body should not error")
require.Equal(s.T(), "", hash, "expected empty hash")
}
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(" \n"))
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func TestPricingServiceSuite(t *testing.T) {
suite.Run(t, new(PricingServiceSuite))
}

View File

@@ -16,10 +16,14 @@ import (
"golang.org/x/net/proxy"
)
type proxyProbeService struct{}
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{}
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
ipInfoURL string
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
@@ -34,7 +38,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}

View File

@@ -0,0 +1,121 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ProxyProbeServiceSuite struct {
suite.Suite
ctx context.Context
proxySrv *httptest.Server
prober *proxyProbeService
}
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
}
func (s *ProxyProbeServiceSuite) TearDownTest() {
if s.proxySrv != nil {
s.proxySrv.Close()
s.proxySrv = nil
}
}
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler)
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
_, err := createProxyTransport("://bad")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "invalid proxy URL")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
_, err := createProxyTransport("ftp://127.0.0.1:1")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
require.NoError(s.T(), err, "createProxyTransport")
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
seen := make(chan string, 1)
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.NoError(s.T(), err, "ProbeProxy")
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
require.Equal(s.T(), "1.2.3.4", info.IP)
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status: 503")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to parse response")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
s.prober.ipInfoURL = "://invalid-url"
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.proxySrv.Close()
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error when proxy server is closed")
}
func TestProxyProbeServiceSuite(t *testing.T) {
suite.Run(t, new(ProxyProbeServiceSuite))
}

View File

@@ -0,0 +1,302 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ProxyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *ProxyRepository
}
func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db)
}
func TestProxyRepoSuite(t *testing.T) {
suite.Run(t, new(ProxyRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{
Name: "test-create",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, proxy)
s.Require().NoError(err, "Create")
s.Require().NotZero(proxy.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, proxy.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(proxies, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal("socks5", proxies[0].Protocol)
}
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
}
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Contains(proxies[0].Name, "production")
}
// --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(proxies, 1)
s.Require().Equal("active1", proxies[0].Name)
}
// --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "user",
Password: "pass",
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists)
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
s.Require().NoError(err)
s.Require().False(notExists)
}
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p-noauth",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
s.Require().NoError(err)
s.Require().True(exists)
}
// --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count)
}
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
}
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err)
s.Require().Empty(counts)
}
// --- ListActiveWithAccountCount ---
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Status: model.StatusActive,
CreatedAt: base.Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p2",
Status: model.StatusActive,
CreatedAt: base,
})
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p3-inactive",
Status: model.StatusDisabled,
})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 active proxies")
// Sorted by created_at DESC, so p2 first
s.Require().Equal(p2.ID, withCounts[0].ID)
s.Require().Equal(int64(1), withCounts[0].AccountCount)
s.Require().Equal(p1.ID, withCounts[1].ID)
s.Require().Equal(int64(2), withCounts[1].AccountCount)
}
// --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "u",
Password: "p",
CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p2",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 proxies")
for _, pc := range withCounts {
switch pc.ID {
case p1.ID:
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
case p2.ID:
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
default:
s.Require().Fail("unexpected proxy id", pc.ID)
}
}
}

View File

@@ -0,0 +1,105 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type RedeemCacheSuite struct {
IntegrationRedisSuite
cache *redeemCache
}
func (s *RedeemCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewRedeemCache(s.rdb).(*redeemCache)
}
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
missingUserID := int64(99999)
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key")
require.True(s.T(), errors.Is(err, redis.Nil))
}
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
userID := int64(1)
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount")
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
require.NoError(s.T(), err, "GetRedeemAttemptCount")
require.Equal(s.T(), 1, count, "count mismatch")
ttl, err := s.rdb.TTL(s.ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration)
}
func (s *RedeemCacheSuite) TestMultipleIncrements() {
userID := int64(2)
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, count, "count after 3 increments")
}
func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() {
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock")
require.True(s.T(), ok)
// Second acquire should fail
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock 2")
require.False(s.T(), ok, "expected lock to be held")
// Release
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock")
// Now acquire should succeed
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock after release")
require.True(s.T(), ok)
}
func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() {
lockKey := redeemLockKeyPrefix + "CODE2"
lockTTL := 15 * time.Second
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL)
require.NoError(s.T(), err, "AcquireRedeemLock CODE2")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, lockKey).Result()
require.NoError(s.T(), err, "TTL lock key")
s.AssertTTLWithin(ttl, 1*time.Second, lockTTL)
}
func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() {
// Release a lock that doesn't exist should not error
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT"))
// Acquire, release, release again
ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"))
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent")
}
func TestRedeemCacheSuite(t *testing.T) {
suite.Run(t, new(RedeemCacheSuite))
}

View File

@@ -0,0 +1,315 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type RedeemCodeRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *RedeemCodeRepository
}
func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewRedeemCodeRepository(s.db)
}
func TestRedeemCodeRepoSuite(t *testing.T) {
suite.Run(t, new(RedeemCodeRepoSuite))
}
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
code := &model.RedeemCode{
Code: "TEST-CREATE",
Type: model.RedeemTypeBalance,
Value: 100,
Status: model.StatusUnused,
}
err := s.repo.Create(s.ctx, code)
s.Require().NoError(err, "Create")
s.Require().NotZero(code.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("TEST-CREATE", got.Code)
}
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []model.RedeemCode{
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused},
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused},
}
err := s.repo.CreateBatch(s.ctx, codes)
s.Require().NoError(err, "CreateBatch")
got1, err := s.repo.GetByCode(s.ctx, "BATCH-1")
s.Require().NoError(err)
s.Require().Equal(float64(10), got1.Value)
got2, err := s.repo.GetByCode(s.ctx, "BATCH-2")
s.Require().NoError(err)
s.Require().Equal(float64(20), got2.Value)
}
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance})
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode")
s.Require().Equal("GET-BY-CODE", got.Code)
}
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
s.Require().Error(err, "expected error for non-existent code")
}
// --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance})
err := s.repo.Delete(s.ctx, code.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, code.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance})
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(codes, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.StatusUsed, codes[0].Status)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Contains(codes[0].Code, "ALPHA")
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "WITH-GROUP",
Type: model.RedeemTypeSubscription,
GroupID: &group.ID,
})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().NotNil(codes[0].Group, "expected Group preload")
s.Require().Equal(group.ID, codes[0].Group.ID)
}
// --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10})
code.Value = 50
err := s.repo.Update(s.ctx, code)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(float64(50), got.Value)
}
// --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt)
}
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time")
// Second use should fail
err = s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
// --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "USER-1",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c1).Update("used_at", base)
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "USER-2",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
s.Require().Len(codes, 2)
// Ordered by used_at DESC, so USER-2 first
s.Require().Equal("USER-2", codes[0].Code)
s.Require().Equal("USER-1", codes[1].Code)
}
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "WITH-GRP",
Type: model.RedeemTypeSubscription,
Status: model.StatusUsed,
UsedBy: &user.ID,
GroupID: &group.ID,
})
s.db.Model(c).Update("used_at", time.Now())
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().NotNil(codes[0].Group)
s.Require().Equal(group.ID, codes[0].Group.ID)
}
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "DEF-LIM",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c).Update("used_at", time.Now())
// limit <= 0 should default to 10
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
s.Require().NoError(err)
s.Require().Len(codes, 1)
}
// --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"})
codes := []model.RedeemCode{
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
}
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1)
s.Require().NotNil(list[0].Group, "expected Group preload")
s.Require().Equal(group.ID, list[0].Group.ID)
codeB, err := s.repo.GetByCode(s.ctx, "CODEB")
s.Require().NoError(err, "GetByCode")
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
s.Require().Len(used, 2, "expected 2 used codes")
s.Require().Equal("CODEA", used[0].Code, "expected newest used code first")
}

View File

@@ -0,0 +1,108 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type SettingRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *SettingRepository
}
func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewSettingRepository(s.db)
}
func TestSettingRepoSuite(t *testing.T) {
suite.Run(t, new(SettingRepoSuite))
}
func (s *SettingRepoSuite) TestSetAndGetValue() {
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
got, err := s.repo.GetValue(s.ctx, "k1")
s.Require().NoError(err, "GetValue")
s.Require().Equal("v1", got, "GetValue mismatch")
}
func (s *SettingRepoSuite) TestSet_Upsert() {
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v2"), "Set upsert")
got, err := s.repo.GetValue(s.ctx, "k1")
s.Require().NoError(err, "GetValue after upsert")
s.Require().Equal("v2", got, "upsert mismatch")
}
func (s *SettingRepoSuite) TestGetValue_Missing() {
_, err := s.repo.GetValue(s.ctx, "nonexistent")
s.Require().Error(err, "expected error for missing key")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"k2": "v2", "k3": "v3"}), "SetMultiple")
m, err := s.repo.GetMultiple(s.ctx, []string{"k2", "k3"})
s.Require().NoError(err, "GetMultiple")
s.Require().Equal("v2", m["k2"])
s.Require().Equal("v3", m["k3"])
}
func (s *SettingRepoSuite) TestGetMultiple_EmptyKeys() {
m, err := s.repo.GetMultiple(s.ctx, []string{})
s.Require().NoError(err, "GetMultiple with empty keys")
s.Require().Empty(m, "expected empty map")
}
func (s *SettingRepoSuite) TestGetMultiple_Subset() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"a": "1", "b": "2", "c": "3"}))
m, err := s.repo.GetMultiple(s.ctx, []string{"a", "c", "nonexistent"})
s.Require().NoError(err, "GetMultiple subset")
s.Require().Equal("1", m["a"])
s.Require().Equal("3", m["c"])
_, exists := m["nonexistent"]
s.Require().False(exists, "nonexistent key should not be in map")
}
func (s *SettingRepoSuite) TestGetAll() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"x": "1", "y": "2"}))
all, err := s.repo.GetAll(s.ctx)
s.Require().NoError(err, "GetAll")
s.Require().GreaterOrEqual(len(all), 2, "expected at least 2 settings")
s.Require().Equal("1", all["x"])
s.Require().Equal("2", all["y"])
}
func (s *SettingRepoSuite) TestDelete() {
s.Require().NoError(s.repo.Set(s.ctx, "todelete", "val"))
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
_, err := s.repo.GetValue(s.ctx, "todelete")
s.Require().Error(err, "expected missing key error after Delete")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *SettingRepoSuite) TestDelete_Idempotent() {
// Delete a key that doesn't exist should not error
s.Require().NoError(s.repo.Delete(s.ctx, "nonexistent_delete"), "Delete nonexistent should be idempotent")
}
func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
s.Require().NoError(s.repo.Set(s.ctx, "upsert_key", "old_value"))
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"upsert_key": "new_value", "new_key": "new_val"}))
got, err := s.repo.GetValue(s.ctx, "upsert_key")
s.Require().NoError(err)
s.Require().Equal("new_value", got, "SetMultiple should upsert existing key")
got2, err := s.repo.GetValue(s.ctx, "new_key")
s.Require().NoError(err)
s.Require().Equal("new_val", got2)
}

View File

@@ -16,6 +16,7 @@ const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/sitev
type turnstileVerifier struct {
httpClient *http.Client
verifyURL string
}
func NewTurnstileVerifier() service.TurnstileVerifier {
@@ -23,6 +24,7 @@ func NewTurnstileVerifier() service.TurnstileVerifier {
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
verifyURL: turnstileVerifyURL,
}
}
@@ -34,7 +36,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}

View File

@@ -0,0 +1,143 @@
package repository
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type TurnstileServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
verifier *turnstileVerifier
received chan url.Values
}
func (s *TurnstileServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
verifier, ok := NewTurnstileVerifier().(*turnstileVerifier)
require.True(s.T(), ok, "type assertion failed")
s.verifier = verifier
}
func (s *TurnstileServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err, "VerifyToken")
require.NotNil(s.T(), resp)
require.True(s.T(), resp.Success, "expected success response")
// Assert form fields in main goroutine
select {
case values := <-s.received:
require.Equal(s.T(), "sk", values.Get("secret"))
require.Equal(s.T(), "token", values.Get("response"))
require.Equal(s.T(), "1.1.1.1", values.Get("remoteip"))
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err)
require.True(s.T(), strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), "unexpected content-type: %s", contentType)
}
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "")
require.NoError(s.T(), err)
select {
case values := <-s.received:
require.Equal(s.T(), "", values.Get("remoteip"), "remoteip should be empty or not sent")
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed")
}
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false,
ErrorCodes: []string{"invalid-input-response"},
})
}))
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err, "VerifyToken should not error on success=false")
require.NotNil(s.T(), resp)
require.False(s.T(), resp.Success)
require.Contains(s.T(), resp.ErrorCodes, "invalid-input-response")
}
func TestTurnstileServiceSuite(t *testing.T) {
suite.Run(t, new(TurnstileServiceSuite))
}

View File

@@ -0,0 +1,73 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type UpdateCacheSuite struct {
IntegrationRedisSuite
cache *updateCache
}
func (s *UpdateCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewUpdateCache(s.rdb).(*updateCache)
}
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
_, err := s.cache.GetUpdateInfo(s.ctx)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
}
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err, "GetUpdateInfo")
require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err, "TTL updateCacheKey")
s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
// TTL=0 means persist forever (no expiry) in Redis SET command
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v0.0.0", info)
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err)
// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
}
func TestUpdateCacheSuite(t *testing.T) {
suite.Run(t, new(UpdateCacheSuite))
}

View File

@@ -0,0 +1,890 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UsageLogRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UsageLogRepository
}
func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db)
}
func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog {
log := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: inputTokens,
OutputTokens: outputTokens,
TotalCost: cost,
ActualCost: cost,
CreatedAt: createdAt,
}
s.Require().NoError(s.repo.Create(s.ctx, log))
return log
}
// --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"})
log := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.4,
}
err := s.repo.Create(s.ctx, log)
s.Require().NoError(err, "Create")
s.Require().NotZero(log.ID)
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(log.ID, got.ID)
s.Require().Equal(10, got.InputTokens)
}
func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
err := s.repo.Delete(s.ctx, log.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, log.ID)
s.Require().Error(err, "expected error after delete")
}
// --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByUser(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUser")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
// --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
// --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
logs, page, err := s.repo.ListByAccount(s.ctx, account.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAccount")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
// --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
stats, err := s.repo.GetUserStats(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "GetUserStats")
s.Require().Equal(int64(2), stats.TotalRequests)
s.Require().Equal(int64(25), stats.InputTokens)
s.Require().Equal(int64(45), stats.OutputTokens)
}
// --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{UserID: user.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
// --- GetDashboardStats ---
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now()
todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &model.User{
Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now,
})
userOld := mustCreateUser(s.T(), s.db, &model.User{
Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour),
})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300
logToday := &model.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
GroupID: &group.ID,
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 3,
CacheReadTokens: 4,
TotalCost: 1.5,
ActualCost: 1.2,
DurationMs: &d1,
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
}
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
logOld := &model.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 5,
OutputTokens: 6,
TotalCost: 0.7,
ActualCost: 0.7,
DurationMs: &d2,
CreatedAt: todayStart.Add(-1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
logPerf := &model.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 1,
OutputTokens: 2,
TotalCost: 0.1,
ActualCost: 0.1,
DurationMs: &d3,
CreatedAt: now.Add(-30 * time.Second),
}
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats")
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch")
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch")
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch")
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch")
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch")
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch")
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
wantRpm, wantTpm := s.repo.getPerformanceStats(s.ctx, 0)
s.Require().Equal(wantRpm, stats.Rpm, "Rpm mismatch")
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
}
// --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalRequests)
}
// --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens)
}
// --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
s.Require().NoError(err, "GetBatchUserUsageStats")
s.Require().Len(stats, 2)
s.Require().NotNil(stats[user1.ID])
s.Require().NotNil(stats[user2.ID])
}
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
// --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
// --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
stats, err := s.repo.GetGlobalStats(s.ctx, base.Add(-1*time.Hour), base.Add(2*time.Hour))
s.Require().NoError(err, "GetGlobalStats")
s.Require().Equal(int64(2), stats.TotalRequests)
s.Require().Equal(int64(25), stats.TotalInputTokens)
s.Require().Equal(int64(45), stats.TotalOutputTokens)
}
func maxTime(a, b time.Time) time.Time {
if a.After(b) {
return a
}
return b
}
// --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByUserAndTimeRange(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "ListByUserAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(30*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(45*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByAccountAndTimeRange(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "ListByAccountAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 15,
OutputTokens: 25,
TotalCost: 0.6,
ActualCost: 0.6,
CreatedAt: base.Add(30 * time.Minute),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
log3 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 20,
OutputTokens: 30,
TotalCost: 0.7,
ActualCost: 0.7,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log3))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByModelAndTimeRange(s.ctx, "claude-3-opus", startTime, endTime)
s.Require().NoError(err, "ListByModelAndTimeRange")
s.Require().Len(logs, 2)
}
// --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"})
now := time.Now()
windowStart := now.Add(-10 * time.Minute)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, now.Add(-5*time.Minute))
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, now.Add(-3*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, now.Add(-30*time.Minute)) // outside window
stats, err := s.repo.GetAccountWindowStats(s.ctx, account.ID, windowStart)
s.Require().NoError(err, "GetAccountWindowStats")
s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(70), stats.Tokens) // (10+20) + (15+25)
}
// --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(24*time.Hour)) // next day
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "day")
s.Require().NoError(err, "GetUserUsageTrendByUserID")
s.Require().Len(trend, 2) // 2 different days
}
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(2*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "hour")
s.Require().NoError(err, "GetUserUsageTrendByUserID hourly")
s.Require().Len(trend, 3) // 3 different hours
}
// --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
stats, err := s.repo.GetUserModelStats(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "GetUserModelStats")
s.Require().Len(stats, 2)
// Should be ordered by total_tokens DESC
s.Require().Equal("claude-3-opus", stats[0].Model)
s.Require().Equal(int64(300), stats[0].TotalTokens)
}
// --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
// --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}
// --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
// Create logs on different days
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.4,
CreatedAt: base.Add(12 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.15,
CreatedAt: base.Add(36 * time.Hour), // next day
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base
endTime := base.Add(72 * time.Hour)
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "GetAccountUsageStats")
s.Require().Len(resp.History, 2, "expected 2 days of history")
s.Require().Equal(int64(2), resp.Summary.TotalRequests)
s.Require().Equal(int64(450), resp.Summary.TotalTokens)
s.Require().Len(resp.Models, 2)
}
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base
endTime := base.Add(72 * time.Hour)
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "GetAccountUsageStats empty")
s.Require().Len(resp.History, 0)
s.Require().Equal(int64(0), resp.Summary.TotalRequests)
}
// --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
s.createUsageLog(user2, apiKey2, account, 50, 100, 0.5, base)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetUserUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetUserUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{StartTime: &startTime, EndTime: &endTime}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters time range")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{
UserID: user.ID,
ApiKeyID: apiKey.ID,
StartTime: &startTime,
EndTime: &endTime,
}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters combined")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}

View File

@@ -0,0 +1,448 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UserRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserRepository(s.db)
}
func TestUserRepoSuite(t *testing.T) {
suite.Run(t, new(UserRepoSuite))
}
// --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() {
user := &model.User{
Email: "create@test.com",
Username: "testuser",
Role: model.RoleUser,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, user)
s.Require().NoError(err, "Create")
s.Require().NotZero(user.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("create@test.com", got.Email)
}
func (s *UserRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user.ID, got.ID)
}
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
func (s *UserRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"})
user.Username = "updated"
err := s.repo.Update(s.ctx, user)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Username)
}
func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, user.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(users, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive})
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.StatusActive, users[0].Status)
}
func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser})
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.RoleAdmin, users[0].Role)
}
func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice")
}
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username)
}
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("wx_hello", users[0].Wechat)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour),
})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour),
})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
s.Require().Equal(group.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
}
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Balance: 10,
})
target := mustCreateUser(s.T(), s.db, &model.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Balance: 1,
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(12.5, got.Balance)
}
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(7.0, got.Balance)
}
func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(5.0, got.Balance)
}
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Zero(got.Balance)
}
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(8, got.Concurrency)
}
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(3, got.Concurrency)
}
// --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail")
s.Require().True(exists)
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- RemoveGroupFromAllowedGroups ---
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &model.User{
Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7},
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, groupID)
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got, err := s.repo.GetByID(s.ctx, userA.ID)
s.Require().NoError(err, "GetByID")
for _, id := range got.AllowedGroups {
s.Require().NotEqual(groupID, id, "expected groupID to be removed from allowed_groups")
}
}
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999)
s.Require().NoError(err)
s.Require().Zero(affected, "expected no affected rows")
}
// --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &model.User{
Email: "admin1@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "admin2@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
}
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "user@example.com",
Role: model.RoleUser,
Status: model.StatusActive,
})
_, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().Error(err, "expected error when no admin exists")
}
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "disabled@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{
Email: "active@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
}
// --- Combined original test ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &model.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Balance: 10,
})
user2 := mustCreateUser(s.T(), s.db, &model.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Balance: 1,
})
_ = mustCreateUser(s.T(), s.db, &model.User{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
got, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
got.Username = "Alice2"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
got2, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
got3, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateBalance")
s.Require().Equal(12.5, got3.Balance, "UpdateBalance mismatch")
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
got4, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().Equal(7.5, got4.Balance, "DeductBalance mismatch")
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateConcurrency")
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}

View File

@@ -0,0 +1,733 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UserSubscriptionRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserSubscriptionRepository
}
func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db)
}
func TestUserSubscriptionRepoSuite(t *testing.T) {
suite.Run(t, new(UserSubscriptionRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"})
sub := &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
err := s.repo.Create(s.ctx, sub)
s.Require().NoError(err, "Create")
s.Require().NotZero(sub.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(sub.UserID, got.UserID)
s.Require().Equal(sub.GroupID, got.GroupID)
}
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID,
})
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.User, "expected User preload")
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().NotNil(got.AssignedByUser, "expected AssignedByUser preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().Equal(group.ID, got.Group.ID)
s.Require().Equal(admin.ID, got.AssignedByUser.ID)
}
func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated notes", got.Notes)
}
func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.Delete(s.ctx, sub.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, sub.ID)
s.Require().Error(err, "expected error after delete")
}
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetByUserIDAndGroupID")
s.Require().Equal(sub.ID, got.ID)
s.Require().NotNil(got.Group, "expected Group preload")
}
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
_, err := s.repo.GetByUserIDAndGroupID(s.ctx, 999999, 999999)
s.Require().Error(err, "expected error for non-existent pair")
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"})
// Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID)
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"})
// Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().Error(err, "expected error for expired subscription")
}
// --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListByUserID")
s.Require().Len(subs, 2)
for _, sub := range subs {
s.Require().NotNil(sub.Group, "expected Group preload")
}
}
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status)
}
// --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(subs, 2)
s.Require().Equal(int64(2), page.Total)
for _, sub := range subs {
s.Require().NotNil(sub.User, "expected User preload")
s.Require().NotNil(sub.Group, "expected Group preload")
}
}
// --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired)
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status)
}
// --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
s.Require().NoError(err, "IncrementUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(1.25, got.DailyUsageUSD)
s.Require().Equal(1.25, got.WeeklyUsageUSD)
s.Require().Equal(1.25, got.MonthlyUsageUSD)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(3.5, got.DailyUsageUSD)
}
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
s.Require().NoError(err, "ActivateWindows")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().NotNil(got.DailyWindowStart)
s.Require().NotNil(got.WeeklyWindowStart)
s.Require().NotNil(got.MonthlyWindowStart)
s.Require().True(got.DailyWindowStart.Equal(activateAt))
}
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0,
})
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetDailyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetDailyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.DailyUsageUSD)
s.Require().Equal(20.0, got.WeeklyUsageUSD, "weekly should remain unchanged")
s.Require().True(got.DailyWindowStart.Equal(resetAt))
}
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0,
})
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetWeeklyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetWeeklyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.WeeklyUsageUSD)
s.Require().Equal(30.0, got.MonthlyUsageUSD, "monthly should remain unchanged")
s.Require().True(got.WeeklyWindowStart.Equal(resetAt))
}
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0,
})
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetMonthlyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetMonthlyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.MonthlyUsageUSD)
s.Require().True(got.MonthlyWindowStart.Equal(resetAt))
}
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(model.SubscriptionStatusExpired, got.Status)
}
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
s.Require().NoError(err, "ExtendExpiry")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().True(got.ExpiresAt.Equal(newExpiry))
}
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
s.Require().NoError(err, "UpdateNotes")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal("VIP user", got.Notes)
}
// --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
expired, err := s.repo.ListExpired(s.ctx)
s.Require().NoError(err, "ListExpired")
s.Require().Len(expired, 1)
}
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status)
}
// --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
s.Require().True(exists)
notExists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, 999999)
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(2), count)
}
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
})
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountActiveByGroupID")
s.Require().Equal(int64(1), count, "only future expiry counts as active")
}
// --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "DeleteByGroupID")
s.Require().Equal(int64(2), affected)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined original test ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID, "expected active subscription")
activateAt := time.Now().Add(-25 * time.Hour)
s.Require().NoError(s.repo.ActivateWindows(s.ctx, active.ID, activateAt), "ActivateWindows")
s.Require().NoError(s.repo.IncrementUsage(s.ctx, active.ID, 1.25), "IncrementUsage")
after, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(1.25, after.DailyUsageUSD, "DailyUsageUSD mismatch")
s.Require().Equal(1.25, after.WeeklyUsageUSD, "WeeklyUsageUSD mismatch")
s.Require().Equal(1.25, after.MonthlyUsageUSD, "MonthlyUsageUSD mismatch")
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
resetAt := time.Now().Truncate(time.Microsecond) // truncate to microsecond for DB precision
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID after reset")
s.Require().Equal(0.0, afterReset.DailyUsageUSD, "expected daily usage reset to 0")
s.Require().NotNil(afterReset.DailyWindowStart, "expected DailyWindowStart not nil")
s.Require().True(afterReset.DailyWindowStart.Equal(resetAt), "expected daily window start updated")
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired")
}

View File

@@ -180,6 +180,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
@@ -192,6 +193,8 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)

View File

@@ -388,12 +388,14 @@ func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
"stream": true,
}
// OAuth accounts using ChatGPT internal API require store: false and instructions
// OAuth accounts using ChatGPT internal API require store: false
if isOAuth {
payload["store"] = false
payload["instructions"] = openai.DefaultInstructions
}
// All accounts require instructions for Responses API
payload["instructions"] = openai.DefaultInstructions
return payload
}

View File

@@ -45,6 +45,7 @@ type AdminService interface {
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
@@ -140,6 +141,33 @@ type UpdateAccountInput struct {
GroupIDs *[]int64
}
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
Status string
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
}
// BulkUpdateAccountResult captures the result for a single account update.
type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []BulkUpdateAccountResult `json:"results"`
}
type CreateProxyInput struct {
Name string
Protocol string
@@ -694,6 +722,65 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return account, nil
}
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
result := &BulkUpdateAccountsResult{
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
}
if len(input.AccountIDs) == 0 {
return result, nil
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := ports.AccountBulkUpdate{
Credentials: input.Credentials,
Extra: input.Extra,
}
if input.Name != "" {
repoUpdates.Name = &input.Name
}
if input.ProxyID != nil {
repoUpdates.ProxyID = input.ProxyID
}
if input.Concurrency != nil {
repoUpdates.Concurrency = input.Concurrency
}
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
// Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
return nil, err
}
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
entry.Success = true
result.Success++
result.Results = append(result.Results, entry)
}
return result, nil
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
}

View File

@@ -0,0 +1,961 @@
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
type CRSSyncService struct {
accountRepo ports.AccountRepository
proxyRepo ports.ProxyRepository
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
}
func NewCRSSyncService(
accountRepo ports.AccountRepository,
proxyRepo ports.ProxyRepository,
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
proxyRepo: proxyRepo,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
}
}
type SyncFromCRSInput struct {
BaseURL string
Username string
Password string
SyncProxies bool
}
type SyncFromCRSItemResult struct {
CRSAccountID string `json:"crs_account_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Action string `json:"action"` // created/updated/failed/skipped
Error string `json:"error,omitempty"`
}
type SyncFromCRSResult struct {
Created int `json:"created"`
Updated int `json:"updated"`
Skipped int `json:"skipped"`
Failed int `json:"failed"`
Items []SyncFromCRSItemResult `json:"items"`
}
type crsLoginResponse struct {
Success bool `json:"success"`
Token string `json:"token"`
Message string `json:"message"`
Error string `json:"error"`
Username string `json:"username"`
}
type crsExportResponse struct {
Success bool `json:"success"`
Error string `json:"error"`
Message string `json:"message"`
Data struct {
ExportedAt string `json:"exportedAt"`
ClaudeAccounts []crsClaudeAccount `json:"claudeAccounts"`
ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"`
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
} `json:"data"`
}
type crsProxy struct {
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
}
type crsClaudeAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
AuthType string `json:"authType"` // oauth/setup-token
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
type crsConsoleAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
MaxConcurrentTasks int `json:"maxConcurrentTasks"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
}
type crsOpenAIResponsesAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
}
type crsOpenAIOAuthAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
AuthType string `json:"authType"` // oauth
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL)
if err != nil {
return nil, err
}
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required")
}
client := &http.Client{Timeout: 20 * time.Second}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
if err != nil {
return nil, err
}
now := time.Now().UTC().Format(time.RFC3339)
result := &SyncFromCRSResult{
Items: make(
[]SyncFromCRSItemResult,
0,
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts),
),
}
var proxies []model.Proxy
if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx)
}
// Claude OAuth / Setup Token -> sub2api anthropic oauth/setup-token
for _, src := range exported.Data.ClaudeAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
targetType := strings.TrimSpace(src.AuthType)
if targetType == "" {
targetType = "oauth"
}
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken {
item.Action = "skipped"
item.Error = "unsupported authType: " + targetType
result.Skipped++
result.Items = append(result.Items, item)
continue
}
accessToken, _ := src.Credentials["access_token"].(string)
if strings.TrimSpace(accessToken) == "" {
item.Action = "failed"
item.Error = "missing access_token"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
// 🔧 Remove /v1 suffix from base_url for Claude accounts
cleanBaseURL(credentials, "/v1")
// 🔧 Convert expires_at from ISO string to Unix timestamp
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
credentials["expires_at"] = t.Unix()
}
}
// 🔧 Add intercept_warmup_requests if not present (defaults to false)
if _, exists := credentials["intercept_warmup_requests"]; !exists {
credentials["intercept_warmup_requests"] = false
}
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
// 🔧 Preserve all CRS extra fields and add sync metadata
extra := make(map[string]any)
if src.Extra != nil {
for k, v := range src.Extra {
extra[k] = v
}
}
extra["crs_account_id"] = src.ID
extra["crs_kind"] = src.Kind
extra["crs_synced_at"] = now
// Extract org_uuid and account_uuid from CRS credentials to extra
if orgUUID, ok := src.Credentials["org_uuid"]; ok {
extra["org_uuid"] = orgUUID
}
if accountUUID, ok := src.Credentials["account_uuid"]; ok {
extra["account_uuid"] = accountUUID
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Type: targetType,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after creation
if targetType == model.AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
}
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
// Update existing
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Type = targetType
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after update
if targetType == model.AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
}
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// Claude Console API Key -> sub2api anthropic apikey
for _, src := range exported.Data.ClaudeConsoleAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
apiKey, _ := src.Credentials["api_key"].(string)
if strings.TrimSpace(apiKey) == "" {
item.Action = "failed"
item.Error = "missing api_key"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
priority := clampPriority(src.Priority)
concurrency := 3
if src.MaxConcurrentTasks > 0 {
concurrency = src.MaxConcurrentTasks
}
status := mapCRSStatus(src.IsActive, src.Status)
extra := map[string]any{
"crs_account_id": src.ID,
"crs_kind": src.Kind,
"crs_synced_at": now,
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// OpenAI OAuth -> sub2api openai oauth
for _, src := range exported.Data.OpenAIOAuthAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
accessToken, _ := src.Credentials["access_token"].(string)
if strings.TrimSpace(accessToken) == "" {
item.Action = "failed"
item.Error = "missing access_token"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(
ctx,
input.SyncProxies,
&proxies,
src.Proxy,
fmt.Sprintf("crs-%s", src.Name),
)
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
// Normalize token_type
if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" {
credentials["token_type"] = "Bearer"
}
// 🔧 Convert expires_at from ISO string to Unix timestamp
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
credentials["expires_at"] = t.Unix()
}
}
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
// 🔧 Preserve all CRS extra fields and add sync metadata
extra := make(map[string]any)
if src.Extra != nil {
for k, v := range src.Extra {
extra[k] = v
}
}
extra["crs_account_id"] = src.ID
extra["crs_kind"] = src.Kind
extra["crs_synced_at"] = now
// Extract email from CRS extra (crs_email -> email)
if crsEmail, ok := src.Extra["crs_email"]; ok {
extra["email"] = crsEmail
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeOAuth,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after creation
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after update
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// OpenAI Responses API Key -> sub2api openai apikey
for _, src := range exported.Data.OpenAIResponsesAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
apiKey, _ := src.Credentials["api_key"].(string)
if strings.TrimSpace(apiKey) == "" {
item.Action = "failed"
item.Error = "missing api_key"
result.Failed++
result.Items = append(result.Items, item)
continue
}
if baseURL, ok := src.Credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" {
src.Credentials["base_url"] = "https://api.openai.com"
}
// 🔧 Remove /v1 suffix from base_url for OpenAI accounts
cleanBaseURL(src.Credentials, "/v1")
proxyID, err := s.mapOrCreateProxy(
ctx,
input.SyncProxies,
&proxies,
src.Proxy,
fmt.Sprintf("crs-%s", src.Name),
)
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
extra := map[string]any{
"crs_account_id": src.ID,
"crs_kind": src.Kind,
"crs_synced_at": now,
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
return result, nil
}
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates.
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
out := make(model.JSONB)
for k, v := range existing {
out[k] = v
}
for k, v := range updates {
out[k] = v
}
return out
}
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) {
if !enabled || src == nil {
return nil, nil
}
protocol := strings.ToLower(strings.TrimSpace(src.Protocol))
switch protocol {
case "socks":
protocol = "socks5"
case "socks5h":
protocol = "socks5"
}
host := strings.TrimSpace(src.Host)
port := src.Port
username := strings.TrimSpace(src.Username)
password := strings.TrimSpace(src.Password)
if protocol == "" || host == "" || port <= 0 {
return nil, nil
}
if protocol != "http" && protocol != "https" && protocol != "socks5" {
return nil, nil
}
// Find existing proxy (active only).
for _, p := range *cached {
if strings.EqualFold(p.Protocol, protocol) &&
p.Host == host &&
p.Port == port &&
p.Username == username &&
p.Password == password {
id := p.ID
return &id, nil
}
}
// Create new proxy
proxy := &model.Proxy{
Name: defaultProxyName(defaultName, protocol, host, port),
Protocol: protocol,
Host: host,
Port: port,
Username: username,
Password: password,
Status: model.StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
*cached = append(*cached, *proxy)
id := proxy.ID
return &id, nil
}
func defaultProxyName(base, protocol, host string, port int) string {
base = strings.TrimSpace(base)
if base == "" {
base = "crs"
}
return fmt.Sprintf("%s (%s://%s:%d)", base, protocol, host, port)
}
func defaultName(name, id string) string {
if strings.TrimSpace(name) != "" {
return strings.TrimSpace(name)
}
return "CRS " + id
}
func clampPriority(priority int) int {
if priority < 1 || priority > 100 {
return 50
}
return priority
}
func sanitizeCredentialsMap(input map[string]any) map[string]any {
if input == nil {
return map[string]any{}
}
out := make(map[string]any, len(input))
for k, v := range input {
// Avoid nil values to keep JSONB cleaner
if v != nil {
out[k] = v
}
}
return out
}
func mapCRSStatus(isActive bool, status string) string {
if !isActive {
return "inactive"
}
if strings.EqualFold(strings.TrimSpace(status), "error") {
return "error"
}
return "active"
}
func normalizeBaseURL(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("base_url is required")
}
u, err := url.Parse(trimmed)
if err != nil || u.Scheme == "" || u.Host == "" {
return "", fmt.Errorf("invalid base_url: %s", trimmed)
}
u.Path = strings.TrimRight(u.Path, "/")
return strings.TrimRight(u.String(), "/"), nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials
// Used for both Claude and OpenAI accounts to remove /v1
func cleanBaseURL(credentials map[string]any, suffixToRemove string) {
if baseURL, ok := credentials["base_url"].(string); ok && baseURL != "" {
trimmed := strings.TrimSpace(baseURL)
if strings.HasSuffix(trimmed, suffixToRemove) {
credentials["base_url"] = strings.TrimSuffix(trimmed, suffixToRemove)
}
}
}
func crsLogin(ctx context.Context, client *http.Client, baseURL, username, password string) (string, error) {
payload := map[string]any{
"username": username,
"password": password,
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/web/auth/login", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("crs login failed: status=%d body=%s", resp.StatusCode, string(raw))
}
var parsed crsLoginResponse
if err := json.Unmarshal(raw, &parsed); err != nil {
return "", fmt.Errorf("crs login parse failed: %w", err)
}
if !parsed.Success || strings.TrimSpace(parsed.Token) == "" {
msg := parsed.Message
if msg == "" {
msg = parsed.Error
}
if msg == "" {
msg = "unknown error"
}
return "", errors.New("crs login failed: " + msg)
}
return parsed.Token, nil
}
func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminToken string) (*crsExportResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/admin/sync/export-accounts?include_secrets=true", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+adminToken)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 5<<20))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("crs export failed: status=%d body=%s", resp.StatusCode, string(raw))
}
var parsed crsExportResponse
if err := json.Unmarshal(raw, &parsed); err != nil {
return nil, fmt.Errorf("crs export parse failed: %w", err)
}
if !parsed.Success {
msg := parsed.Message
if msg == "" {
msg = parsed.Error
}
if msg == "" {
msg = "unknown error"
}
return nil, errors.New("crs export failed: " + msg)
}
return &parsed, nil
}
// refreshOAuthToken attempts to refresh OAuth token for a synced account
// Returns updated credentials or nil if refresh failed/not applicable
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB {
if account.Type != model.AccountTypeOAuth {
return nil
}
var newCredentials map[string]any
var err error
switch account.Platform {
case model.PlatformAnthropic:
if s.oauthService == nil {
return nil
}
tokenInfo, refreshErr := s.oauthService.RefreshAccountToken(ctx, account)
if refreshErr != nil {
err = refreshErr
} else {
// Preserve existing credentials
newCredentials = make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
// Update token fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
}
case model.PlatformOpenAI:
if s.openaiOAuthService == nil {
return nil
}
tokenInfo, refreshErr := s.openaiOAuthService.RefreshAccountToken(ctx, account)
if refreshErr != nil {
err = refreshErr
} else {
newCredentials = s.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
}
default:
return nil
}
if err != nil {
// Log but don't fail the sync - token might still be valid or refreshable later
return nil
}
return model.JSONB(newCredentials)
}

View File

@@ -20,6 +20,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
)
@@ -360,8 +362,8 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
// 重试相关常量
const (
maxRetries = 3 // 最大重试次数
retryDelay = 2 * time.Second // 重试等待时间
maxRetries = 5 // 最大重试次数
retryDelay = 6 * time.Second // 重试等待时间
)
// shouldRetryUpstreamError 判断是否应该重试上游错误
@@ -390,6 +392,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
return nil, fmt.Errorf("parse request: %w", err)
}
if !gjson.GetBytes(body, "system").Exists() {
body, _ = sjson.SetBytes(body, "system", []any{
map[string]any{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
"cache_control": map[string]string{
"type": "ephemeral",
},
},
})
}
// 应用模型映射仅对apikey类型账号
originalModel := req.Model
if account.Type == model.AccountTypeApiKey {
@@ -622,6 +636,9 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
var statusCode int
switch resp.StatusCode {
case 400:
c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401:
statusCode = http.StatusBadGateway
errType = "upstream_error"

View File

@@ -827,6 +827,112 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
}
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
// Normalize to canonical 5h/7d fields based on window_minutes
// This fixes the issue where OpenAI's primary/secondary naming is reversed
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
// IMPORTANT: We can only reliably determine window type from window_minutes field
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
var primaryWindowMins, secondaryWindowMins int
var hasPrimaryWindow, hasSecondaryWindow bool
// Only use window_minutes for reliable window size comparison
if snapshot.PrimaryWindowMinutes != nil {
primaryWindowMins = *snapshot.PrimaryWindowMinutes
hasPrimaryWindow = true
}
if snapshot.SecondaryWindowMinutes != nil {
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
hasSecondaryWindow = true
}
// Determine which is 5h and which is 7d
var use5hFromPrimary, use7dFromPrimary bool
var use5hFromSecondary, use7dFromSecondary bool
if hasPrimaryWindow && hasSecondaryWindow {
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
if primaryWindowMins < secondaryWindowMins {
use5hFromPrimary = true
use7dFromSecondary = true
} else {
use5hFromSecondary = true
use7dFromPrimary = true
}
} else if hasPrimaryWindow {
// Only primary window size known: classify by absolute threshold
if primaryWindowMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasSecondaryWindow {
// Only secondary window size known: classify by absolute threshold
if secondaryWindowMins <= 360 {
use5hFromSecondary = true
} else {
use7dFromSecondary = true
}
} else {
// No window_minutes available: cannot reliably determine window types
// Fall back to legacy assumption (may be incorrect)
// Assume primary=7d, secondary=5h based on historical observation
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
use5hFromSecondary = true
}
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
use7dFromPrimary = true
}
}
// Write canonical 5h fields
if use5hFromPrimary {
if snapshot.PrimaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use5hFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
}
}
// Write canonical 7d fields
if use7dFromPrimary {
if snapshot.PrimaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use7dFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
}
}
// Update account's Extra field asynchronously
go func() {
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

View File

@@ -11,6 +11,9 @@ import (
type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error)
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
Update(ctx context.Context, account *model.Account) error
Delete(ctx context.Context, id int64) error
@@ -35,4 +38,17 @@ type AccountRepository interface {
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
}
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
Name *string
ProxyID *int64
Concurrency *int
Priority *int
Status *string
Credentials map[string]any
Extra map[string]any
}

View File

@@ -75,6 +75,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionService,
NewConcurrencyService,
NewIdentityService,
NewCRSSyncService,
ProvideUpdateService,
ProvideTokenRefreshService,

View File

@@ -114,6 +114,8 @@ services:
timeout: 5s
retries: 5
start_period: 10s
ports:
- 5433:5432
# ===========================================================================
# Redis Cache

View File

@@ -3,7 +3,7 @@
* Handles AI platform account management for administrators
*/
import { apiClient } from '../client';
import { apiClient } from '../client'
import type {
Account,
CreateAccountRequest,
@@ -13,7 +13,7 @@ import type {
WindowStats,
ClaudeModel,
AccountUsageStatsResponse,
} from '@/types';
} from '@/types'
/**
* List all accounts with pagination
@@ -26,10 +26,10 @@ export async function list(
page: number = 1,
pageSize: number = 20,
filters?: {
platform?: string;
type?: string;
status?: string;
search?: string;
platform?: string
type?: string
status?: string
search?: string
}
): Promise<PaginatedResponse<Account>> {
const { data } = await apiClient.get<PaginatedResponse<Account>>('/admin/accounts', {
@@ -38,8 +38,8 @@ export async function list(
page_size: pageSize,
...filters,
},
});
return data;
})
return data
}
/**
@@ -48,8 +48,8 @@ export async function list(
* @returns Account details
*/
export async function getById(id: number): Promise<Account> {
const { data } = await apiClient.get<Account>(`/admin/accounts/${id}`);
return data;
const { data } = await apiClient.get<Account>(`/admin/accounts/${id}`)
return data
}
/**
@@ -58,8 +58,8 @@ export async function getById(id: number): Promise<Account> {
* @returns Created account
*/
export async function create(accountData: CreateAccountRequest): Promise<Account> {
const { data } = await apiClient.post<Account>('/admin/accounts', accountData);
return data;
const { data } = await apiClient.post<Account>('/admin/accounts', accountData)
return data
}
/**
@@ -69,8 +69,8 @@ export async function create(accountData: CreateAccountRequest): Promise<Account
* @returns Updated account
*/
export async function update(id: number, updates: UpdateAccountRequest): Promise<Account> {
const { data } = await apiClient.put<Account>(`/admin/accounts/${id}`, updates);
return data;
const { data } = await apiClient.put<Account>(`/admin/accounts/${id}`, updates)
return data
}
/**
@@ -79,8 +79,8 @@ export async function update(id: number, updates: UpdateAccountRequest): Promise
* @returns Success confirmation
*/
export async function deleteAccount(id: number): Promise<{ message: string }> {
const { data } = await apiClient.delete<{ message: string }>(`/admin/accounts/${id}`);
return data;
const { data } = await apiClient.delete<{ message: string }>(`/admin/accounts/${id}`)
return data
}
/**
@@ -89,11 +89,8 @@ export async function deleteAccount(id: number): Promise<{ message: string }> {
* @param status - New status
* @returns Updated account
*/
export async function toggleStatus(
id: number,
status: 'active' | 'inactive'
): Promise<Account> {
return update(id, { status });
export async function toggleStatus(id: number, status: 'active' | 'inactive'): Promise<Account> {
return update(id, { status })
}
/**
@@ -102,16 +99,16 @@ export async function toggleStatus(
* @returns Test result
*/
export async function testAccount(id: number): Promise<{
success: boolean;
message: string;
latency_ms?: number;
success: boolean
message: string
latency_ms?: number
}> {
const { data } = await apiClient.post<{
success: boolean;
message: string;
latency_ms?: number;
}>(`/admin/accounts/${id}/test`);
return data;
success: boolean
message: string
latency_ms?: number
}>(`/admin/accounts/${id}/test`)
return data
}
/**
@@ -120,8 +117,8 @@ export async function testAccount(id: number): Promise<{
* @returns Updated account
*/
export async function refreshCredentials(id: number): Promise<Account> {
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/refresh`);
return data;
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/refresh`)
return data
}
/**
@@ -133,8 +130,8 @@ export async function refreshCredentials(id: number): Promise<Account> {
export async function getStats(id: number, days: number = 30): Promise<AccountUsageStatsResponse> {
const { data } = await apiClient.get<AccountUsageStatsResponse>(`/admin/accounts/${id}/stats`, {
params: { days },
});
return data;
})
return data
}
/**
@@ -143,8 +140,8 @@ export async function getStats(id: number, days: number = 30): Promise<AccountUs
* @returns Updated account
*/
export async function clearError(id: number): Promise<Account> {
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/clear-error`);
return data;
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/clear-error`)
return data
}
/**
@@ -153,8 +150,8 @@ export async function clearError(id: number): Promise<Account> {
* @returns Account usage info
*/
export async function getUsage(id: number): Promise<AccountUsageInfo> {
const { data } = await apiClient.get<AccountUsageInfo>(`/admin/accounts/${id}/usage`);
return data;
const { data } = await apiClient.get<AccountUsageInfo>(`/admin/accounts/${id}/usage`)
return data
}
/**
@@ -163,8 +160,10 @@ export async function getUsage(id: number): Promise<AccountUsageInfo> {
* @returns Success confirmation
*/
export async function clearRateLimit(id: number): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>(`/admin/accounts/${id}/clear-rate-limit`);
return data;
const { data } = await apiClient.post<{ message: string }>(
`/admin/accounts/${id}/clear-rate-limit`
)
return data
}
/**
@@ -177,8 +176,8 @@ export async function generateAuthUrl(
endpoint: string,
config: { proxy_id?: number }
): Promise<{ auth_url: string; session_id: string }> {
const { data } = await apiClient.post<{ auth_url: string; session_id: string }>(endpoint, config);
return data;
const { data } = await apiClient.post<{ auth_url: string; session_id: string }>(endpoint, config)
return data
}
/**
@@ -191,8 +190,8 @@ export async function exchangeCode(
endpoint: string,
exchangeData: { session_id: string; code: string; proxy_id?: number }
): Promise<Record<string, unknown>> {
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData);
return data;
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData)
return data
}
/**
@@ -201,16 +200,63 @@ export async function exchangeCode(
* @returns Results of batch creation
*/
export async function batchCreate(accounts: CreateAccountRequest[]): Promise<{
success: number;
failed: number;
results: Array<{ success: boolean; account?: Account; error?: string }>;
success: number
failed: number
results: Array<{ success: boolean; account?: Account; error?: string }>
}> {
const { data } = await apiClient.post<{
success: number;
failed: number;
results: Array<{ success: boolean; account?: Account; error?: string }>;
}>('/admin/accounts/batch', { accounts });
return data;
success: number
failed: number
results: Array<{ success: boolean; account?: Account; error?: string }>
}>('/admin/accounts/batch', { accounts })
return data
}
/**
* Batch update credentials fields for multiple accounts
* @param request - Batch update request containing account IDs, field name, and value
* @returns Results of batch update
*/
export async function batchUpdateCredentials(request: {
account_ids: number[]
field: string
value: any
}): Promise<{
success: number
failed: number
results: Array<{ account_id: number; success: boolean; error?: string }>
}> {
const { data } = await apiClient.post<{
success: number
failed: number
results: Array<{ account_id: number; success: boolean; error?: string }>
}>('/admin/accounts/batch-update-credentials', request)
return data
}
/**
* Bulk update multiple accounts
* @param accountIds - Array of account IDs
* @param updates - Fields to update
* @returns Success confirmation
*/
export async function bulkUpdate(
accountIds: number[],
updates: Record<string, unknown>
): Promise<{
success: number
failed: number
results: Array<{ account_id: number; success: boolean; error?: string }>
}> {
const { data } = await apiClient.post<{
success: number
failed: number
results: Array<{ account_id: number; success: boolean; error?: string }>
}>('/admin/accounts/bulk-update', {
account_ids: accountIds,
...updates,
})
return data
}
/**
@@ -219,8 +265,8 @@ export async function batchCreate(accounts: CreateAccountRequest[]): Promise<{
* @returns Today's stats (requests, tokens, cost)
*/
export async function getTodayStats(id: number): Promise<WindowStats> {
const { data } = await apiClient.get<WindowStats>(`/admin/accounts/${id}/today-stats`);
return data;
const { data } = await apiClient.get<WindowStats>(`/admin/accounts/${id}/today-stats`)
return data
}
/**
@@ -230,8 +276,10 @@ export async function getTodayStats(id: number): Promise<WindowStats> {
* @returns Updated account
*/
export async function setSchedulable(id: number, schedulable: boolean): Promise<Account> {
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/schedulable`, { schedulable });
return data;
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/schedulable`, {
schedulable,
})
return data
}
/**
@@ -240,8 +288,30 @@ export async function setSchedulable(id: number, schedulable: boolean): Promise<
* @returns List of available models for this account
*/
export async function getAvailableModels(id: number): Promise<ClaudeModel[]> {
const { data } = await apiClient.get<ClaudeModel[]>(`/admin/accounts/${id}/models`);
return data;
const { data } = await apiClient.get<ClaudeModel[]>(`/admin/accounts/${id}/models`)
return data
}
export async function syncFromCrs(params: {
base_url: string
username: string
password: string
sync_proxies?: boolean
}): Promise<{
created: number
updated: number
skipped: number
failed: number
items: Array<{
crs_account_id: string
kind: string
name: string
action: string
error?: string
}>
}> {
const { data } = await apiClient.post('/admin/accounts/sync/crs', params)
return data
}
export const accountsAPI = {
@@ -263,6 +333,9 @@ export const accountsAPI = {
generateAuthUrl,
exchangeCode,
batchCreate,
};
batchUpdateCredentials,
bulkUpdate,
syncFromCrs,
}
export default accountsAPI;
export default accountsAPI

View File

@@ -69,21 +69,21 @@
<!-- OpenAI OAuth accounts: show Codex usage from extra field -->
<template v-else-if="account.platform === 'openai' && account.type === 'oauth'">
<div v-if="hasCodexUsage" class="space-y-1">
<!-- 5h Window (Secondary) -->
<!-- 5h Window -->
<UsageProgressBar
v-if="codexSecondaryUsedPercent !== null"
v-if="codex5hUsedPercent !== null"
label="5h"
:utilization="codexSecondaryUsedPercent"
:resets-at="codexSecondaryResetAt"
:utilization="codex5hUsedPercent"
:resets-at="codex5hResetAt"
color="indigo"
/>
<!-- Weekly Window (Primary) -->
<!-- 7d Window -->
<UsageProgressBar
v-if="codexPrimaryUsedPercent !== null"
v-if="codex7dUsedPercent !== null"
label="7d"
:utilization="codexPrimaryUsedPercent"
:resets-at="codexPrimaryResetAt"
:utilization="codex7dUsedPercent"
:resets-at="codex7dResetAt"
color="emerald"
/>
</div>
@@ -125,35 +125,123 @@ const showUsageWindows = computed(() =>
const hasCodexUsage = computed(() => {
const extra = props.account.extra
return extra && (
// Check for new canonical fields first
extra.codex_5h_used_percent !== undefined ||
extra.codex_7d_used_percent !== undefined ||
// Fallback to legacy fields
extra.codex_primary_used_percent !== undefined ||
extra.codex_secondary_used_percent !== undefined
)
})
const codexPrimaryUsedPercent = computed(() => {
// 5h window usage (prefer canonical field)
const codex5hUsedPercent = computed(() => {
const extra = props.account.extra
if (!extra || extra.codex_primary_used_percent === undefined) return null
return extra.codex_primary_used_percent
if (!extra) return null
// Prefer canonical field
if (extra.codex_5h_used_percent !== undefined) {
return extra.codex_5h_used_percent
}
// Fallback: detect from legacy fields using window_minutes
if (extra.codex_primary_window_minutes !== undefined && extra.codex_primary_window_minutes <= 360) {
return extra.codex_primary_used_percent ?? null
}
if (extra.codex_secondary_window_minutes !== undefined && extra.codex_secondary_window_minutes <= 360) {
return extra.codex_secondary_used_percent ?? null
}
// Legacy assumption: secondary = 5h (may be incorrect)
return extra.codex_secondary_used_percent ?? null
})
const codexSecondaryUsedPercent = computed(() => {
const codex5hResetAt = computed(() => {
const extra = props.account.extra
if (!extra || extra.codex_secondary_used_percent === undefined) return null
return extra.codex_secondary_used_percent
if (!extra) return null
// Prefer canonical field
if (extra.codex_5h_reset_after_seconds !== undefined) {
const resetTime = new Date(Date.now() + extra.codex_5h_reset_after_seconds * 1000)
return resetTime.toISOString()
}
// Fallback: detect from legacy fields using window_minutes
if (extra.codex_primary_window_minutes !== undefined && extra.codex_primary_window_minutes <= 360) {
if (extra.codex_primary_reset_after_seconds !== undefined) {
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
return resetTime.toISOString()
}
}
if (extra.codex_secondary_window_minutes !== undefined && extra.codex_secondary_window_minutes <= 360) {
if (extra.codex_secondary_reset_after_seconds !== undefined) {
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
return resetTime.toISOString()
}
}
// Legacy assumption: secondary = 5h
if (extra.codex_secondary_reset_after_seconds !== undefined) {
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
return resetTime.toISOString()
}
return null
})
const codexPrimaryResetAt = computed(() => {
// 7d window usage (prefer canonical field)
const codex7dUsedPercent = computed(() => {
const extra = props.account.extra
if (!extra || extra.codex_primary_reset_after_seconds === undefined) return null
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
return resetTime.toISOString()
if (!extra) return null
// Prefer canonical field
if (extra.codex_7d_used_percent !== undefined) {
return extra.codex_7d_used_percent
}
// Fallback: detect from legacy fields using window_minutes
if (extra.codex_primary_window_minutes !== undefined && extra.codex_primary_window_minutes >= 10000) {
return extra.codex_primary_used_percent ?? null
}
if (extra.codex_secondary_window_minutes !== undefined && extra.codex_secondary_window_minutes >= 10000) {
return extra.codex_secondary_used_percent ?? null
}
// Legacy assumption: primary = 7d (may be incorrect)
return extra.codex_primary_used_percent ?? null
})
const codexSecondaryResetAt = computed(() => {
const codex7dResetAt = computed(() => {
const extra = props.account.extra
if (!extra || extra.codex_secondary_reset_after_seconds === undefined) return null
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
return resetTime.toISOString()
if (!extra) return null
// Prefer canonical field
if (extra.codex_7d_reset_after_seconds !== undefined) {
const resetTime = new Date(Date.now() + extra.codex_7d_reset_after_seconds * 1000)
return resetTime.toISOString()
}
// Fallback: detect from legacy fields using window_minutes
if (extra.codex_primary_window_minutes !== undefined && extra.codex_primary_window_minutes >= 10000) {
if (extra.codex_primary_reset_after_seconds !== undefined) {
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
return resetTime.toISOString()
}
}
if (extra.codex_secondary_window_minutes !== undefined && extra.codex_secondary_window_minutes >= 10000) {
if (extra.codex_secondary_reset_after_seconds !== undefined) {
const resetTime = new Date(Date.now() + extra.codex_secondary_reset_after_seconds * 1000)
return resetTime.toISOString()
}
}
// Legacy assumption: primary = 7d
if (extra.codex_primary_reset_after_seconds !== undefined) {
const resetTime = new Date(Date.now() + extra.codex_primary_reset_after_seconds * 1000)
return resetTime.toISOString()
}
return null
})
const loadUsage = async () => {

View File

@@ -0,0 +1,892 @@
<template>
<Modal :show="show" :title="t('admin.accounts.bulkEdit.title')" size="lg" @close="handleClose">
<form class="space-y-5" @submit.prevent="handleSubmit">
<!-- Info -->
<div class="rounded-lg bg-blue-50 dark:bg-blue-900/20 p-4">
<p class="text-sm text-blue-700 dark:text-blue-400">
<svg class="w-5 h-5 inline mr-1.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
{{ t('admin.accounts.bulkEdit.selectionInfo', { count: accountIds.length }) }}
</p>
</div>
<!-- Base URL (API Key only) -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between mb-3">
<label class="input-label mb-0">{{ t('admin.accounts.baseUrl') }}</label>
<input
v-model="enableBaseUrl"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<input
v-model="baseUrl"
type="text"
:disabled="!enableBaseUrl"
class="input"
:class="!enableBaseUrl && 'opacity-50 cursor-not-allowed'"
:placeholder="t('admin.accounts.bulkEdit.baseUrlPlaceholder')"
/>
<p class="input-hint">
{{ t('admin.accounts.bulkEdit.baseUrlNotice') }}
</p>
</div>
<!-- Model restriction -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between mb-3">
<label class="input-label mb-0">{{ t('admin.accounts.modelRestriction') }}</label>
<input
v-model="enableModelRestriction"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div :class="!enableModelRestriction && 'opacity-50 pointer-events-none'">
<!-- Mode Toggle -->
<div class="flex gap-2 mb-4">
<button
type="button"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'whitelist'
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500',
]"
@click="modelRestrictionMode = 'whitelist'"
>
<svg
class="w-4 h-4 inline mr-1.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
{{ t('admin.accounts.modelWhitelist') }}
</button>
<button
type="button"
:class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'mapping'
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500',
]"
@click="modelRestrictionMode = 'mapping'"
>
<svg
class="w-4 h-4 inline mr-1.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M8 7h12m0 0l-4-4m4 4l-4 4m0 6H4m0 0l4 4m-4-4l4-4"
/>
</svg>
{{ t('admin.accounts.modelMapping') }}
</button>
</div>
<!-- Whitelist Mode -->
<div v-if="modelRestrictionMode === 'whitelist'">
<div class="mb-3 rounded-lg bg-blue-50 dark:bg-blue-900/20 p-3">
<p class="text-xs text-blue-700 dark:text-blue-400">
<svg
class="w-4 h-4 inline mr-1"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
{{ t('admin.accounts.selectAllowedModels') }}
</p>
</div>
<!-- Model Checkbox List -->
<div class="grid grid-cols-2 gap-2 mb-3">
<label
v-for="model in allModels"
:key="model.value"
class="flex cursor-pointer items-center rounded-lg border p-3 transition-all hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700"
:class="
allowedModels.includes(model.value)
? 'border-primary-500 bg-primary-50 dark:bg-primary-900/20'
: 'border-gray-200'
"
>
<input
v-model="allowedModels"
type="checkbox"
:value="model.value"
class="mr-2 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ model.label }}</span>
</label>
</div>
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
<span v-if="allowedModels.length === 0">{{
t('admin.accounts.supportsAllModels')
}}</span>
</p>
</div>
<!-- Mapping Mode -->
<div v-else>
<div class="mb-3 rounded-lg bg-purple-50 dark:bg-purple-900/20 p-3">
<p class="text-xs text-purple-700 dark:text-purple-400">
<svg
class="w-4 h-4 inline mr-1"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
{{ t('admin.accounts.mapRequestModels') }}
</p>
</div>
<!-- Model Mapping List -->
<div v-if="modelMappings.length > 0" class="space-y-2 mb-3">
<div
v-for="(mapping, index) in modelMappings"
:key="index"
class="flex items-center gap-2"
>
<input
v-model="mapping.from"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.requestModel')"
/>
<svg
class="w-4 h-4 text-gray-400 flex-shrink-0"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M14 5l7 7m0 0l-7 7m7-7H3"
/>
</svg>
<input
v-model="mapping.to"
type="text"
class="input flex-1"
:placeholder="t('admin.accounts.actualModel')"
/>
<button
type="button"
class="p-2 text-red-500 hover:text-red-600 hover:bg-red-50 dark:hover:bg-red-900/20 rounded-lg transition-colors"
@click="removeModelMapping(index)"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
/>
</svg>
</button>
</div>
</div>
<button
type="button"
class="w-full rounded-lg border-2 border-dashed border-gray-300 dark:border-dark-500 px-4 py-2 text-gray-600 dark:text-gray-400 transition-colors hover:border-gray-400 hover:text-gray-700 dark:hover:border-dark-400 dark:hover:text-gray-300 mb-3"
@click="addModelMapping"
>
<svg
class="w-4 h-4 inline mr-1"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M12 4v16m8-8H4"
/>
</svg>
{{ t('admin.accounts.addMapping') }}
</button>
<!-- Quick Add Buttons -->
<div class="flex flex-wrap gap-2">
<button
v-for="preset in presetMappings"
:key="preset.label"
type="button"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
@click="addPresetMapping(preset.from, preset.to)"
>
+ {{ preset.label }}
</button>
</div>
</div>
</div>
</div>
<!-- Custom error codes -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between mb-3">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">
{{ t('admin.accounts.customErrorCodesHint') }}
</p>
</div>
<input
v-model="enableCustomErrorCodes"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div v-if="enableCustomErrorCodes" class="space-y-3">
<div class="rounded-lg bg-amber-50 dark:bg-amber-900/20 p-3">
<p class="text-xs text-amber-700 dark:text-amber-400">
<svg
class="w-4 h-4 inline mr-1"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
/>
</svg>
{{ t('admin.accounts.customErrorCodesWarning') }}
</p>
</div>
<!-- Error Code Buttons -->
<div class="flex flex-wrap gap-2">
<button
v-for="code in commonErrorCodes"
:key="code.value"
type="button"
:class="[
'rounded-lg px-3 py-1.5 text-sm font-medium transition-colors',
selectedErrorCodes.includes(code.value)
? 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400 ring-1 ring-red-500'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500',
]"
@click="toggleErrorCode(code.value)"
>
{{ code.value }} {{ code.label }}
</button>
</div>
<!-- Manual input -->
<div class="flex items-center gap-2">
<input
v-model="customErrorCodeInput"
type="number"
min="100"
max="599"
class="input flex-1"
:placeholder="t('admin.accounts.enterErrorCode')"
@keyup.enter="addCustomErrorCode"
/>
<button type="button" class="btn btn-secondary px-3" @click="addCustomErrorCode">
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M12 4v16m8-8H4"
/>
</svg>
</button>
</div>
<!-- Selected codes summary -->
<div class="flex flex-wrap gap-1.5">
<span
v-for="code in selectedErrorCodes.sort((a, b) => a - b)"
:key="code"
class="inline-flex items-center gap-1 rounded-full bg-red-100 dark:bg-red-900/30 px-2.5 py-0.5 text-sm font-medium text-red-700 dark:text-red-400"
>
{{ code }}
<button
type="button"
class="hover:text-red-900 dark:hover:text-red-300"
@click="removeErrorCode(code)"
>
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
</span>
<span v-if="selectedErrorCodes.length === 0" class="text-xs text-gray-400">
{{ t('admin.accounts.noneSelectedUsesDefault') }}
</span>
</div>
</div>
</div>
<!-- Intercept warmup requests (Anthropic only) -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between">
<div class="flex-1 pr-4">
<label class="input-label mb-0">{{
t('admin.accounts.interceptWarmupRequests')
}}</label>
<p class="text-xs text-gray-500 dark:text-gray-400 mt-1">
{{ t('admin.accounts.interceptWarmupRequestsDesc') }}
</p>
</div>
<input
v-model="enableInterceptWarmup"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div v-if="enableInterceptWarmup" class="mt-3">
<button
type="button"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
interceptWarmupRequests ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600',
]"
@click="interceptWarmupRequests = !interceptWarmupRequests"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
interceptWarmupRequests ? 'translate-x-5' : 'translate-x-0',
]"
/>
</button>
</div>
</div>
<!-- Proxy -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between mb-3">
<label class="input-label mb-0">{{ t('admin.accounts.proxy') }}</label>
<input
v-model="enableProxy"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div :class="!enableProxy && 'opacity-50 pointer-events-none'">
<ProxySelector v-model="proxyId" :proxies="proxies" />
</div>
</div>
<!-- Concurrency & Priority -->
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 dark:border-dark-600 pt-4">
<div>
<div class="flex items-center justify-between mb-3">
<label class="input-label mb-0">{{ t('admin.accounts.concurrency') }}</label>
<input
v-model="enableConcurrency"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<input
v-model.number="concurrency"
type="number"
min="1"
:disabled="!enableConcurrency"
class="input"
:class="!enableConcurrency && 'opacity-50 cursor-not-allowed'"
/>
</div>
<div>
<div class="flex items-center justify-between mb-3">
<label class="input-label mb-0">{{ t('admin.accounts.priority') }}</label>
<input
v-model="enablePriority"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<input
v-model.number="priority"
type="number"
min="1"
:disabled="!enablePriority"
class="input"
:class="!enablePriority && 'opacity-50 cursor-not-allowed'"
/>
</div>
</div>
<!-- Status -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between mb-3">
<label class="input-label mb-0">{{ t('common.status') }}</label>
<input
v-model="enableStatus"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div :class="!enableStatus && 'opacity-50 pointer-events-none'">
<Select v-model="status" :options="statusOptions" />
</div>
</div>
<!-- Groups -->
<div class="border-t border-gray-200 dark:border-dark-600 pt-4">
<div class="flex items-center justify-between mb-3">
<label class="input-label mb-0">{{ t('nav.groups') }}</label>
<input
v-model="enableGroups"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div :class="!enableGroups && 'opacity-50 pointer-events-none'">
<GroupSelector v-model="groupIds" :groups="groups" />
</div>
</div>
<!-- Action buttons -->
<div class="flex justify-end gap-3 pt-4">
<button type="button" class="btn btn-secondary" @click="handleClose">
{{ t('common.cancel') }}
</button>
<button type="submit" :disabled="submitting" class="btn btn-primary">
<svg
v-if="submitting"
class="animate-spin -ml-1 mr-2 h-4 w-4"
fill="none"
viewBox="0 0 24 24"
>
<circle
class="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
stroke-width="4"
/>
<path
class="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
/>
</svg>
{{
submitting ? t('admin.accounts.bulkEdit.updating') : t('admin.accounts.bulkEdit.submit')
}}
</button>
</div>
</form>
</Modal>
</template>
<script setup lang="ts">
import { ref, watch, computed } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
import type { Proxy, Group } from '@/types'
import Modal from '@/components/common/Modal.vue'
import Select from '@/components/common/Select.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
interface Props {
show: boolean
accountIds: number[]
proxies: Proxy[]
groups: Group[]
}
const props = defineProps<Props>()
const emit = defineEmits<{
close: []
updated: []
}>()
const { t } = useI18n()
const appStore = useAppStore()
// Model mapping type
interface ModelMapping {
from: string
to: string
}
// State - field enable flags
const enableBaseUrl = ref(false)
const enableModelRestriction = ref(false)
const enableCustomErrorCodes = ref(false)
const enableInterceptWarmup = ref(false)
const enableProxy = ref(false)
const enableConcurrency = ref(false)
const enablePriority = ref(false)
const enableStatus = ref(false)
const enableGroups = ref(false)
// State - field values
const submitting = ref(false)
const baseUrl = ref('')
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([])
const modelMappings = ref<ModelMapping[]>([])
const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
const proxyId = ref<number | null>(null)
const concurrency = ref(1)
const priority = ref(1)
const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([])
// All models list (combined Anthropic + OpenAI)
const allModels = [
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
{ value: 'claude-3-5-haiku-20241022', label: 'Claude 3.5 Haiku' },
{ value: 'claude-haiku-4-5-20251001', label: 'Claude Haiku 4.5' },
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
]
// Preset mappings (combined Anthropic + OpenAI)
const presetMappings = [
{
label: 'Sonnet 4',
from: 'claude-sonnet-4-20250514',
to: 'claude-sonnet-4-20250514',
color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400',
},
{
label: 'Sonnet 4.5',
from: 'claude-sonnet-4-5-20250929',
to: 'claude-sonnet-4-5-20250929',
color:
'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400',
},
{
label: 'Opus 4.5',
from: 'claude-opus-4-5-20251101',
to: 'claude-opus-4-5-20251101',
color:
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400',
},
{
label: 'Opus->Sonnet',
from: 'claude-opus-4-5-20251101',
to: 'claude-sonnet-4-5-20250929',
color:
'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400',
},
{
label: 'GPT-5.2',
from: 'gpt-5.2-2025-12-11',
to: 'gpt-5.2-2025-12-11',
color:
'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400',
},
{
label: 'GPT-5.2 Codex',
from: 'gpt-5.2-codex',
to: 'gpt-5.2-codex',
color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400',
},
{
label: 'Max->Codex',
from: 'gpt-5.1-codex-max',
to: 'gpt-5.1-codex',
color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400',
},
]
// Common HTTP error codes
const commonErrorCodes = [
{ value: 401, label: 'Unauthorized' },
{ value: 403, label: 'Forbidden' },
{ value: 429, label: 'Rate Limit' },
{ value: 500, label: 'Server Error' },
{ value: 502, label: 'Bad Gateway' },
{ value: 503, label: 'Unavailable' },
{ value: 529, label: 'Overloaded' },
]
const statusOptions = computed(() => [
{ value: 'active', label: t('common.active') },
{ value: 'inactive', label: t('common.inactive') },
])
// Model mapping helpers
const addModelMapping = () => {
modelMappings.value.push({ from: '', to: '' })
}
const removeModelMapping = (index: number) => {
modelMappings.value.splice(index, 1)
}
const addPresetMapping = (from: string, to: string) => {
const exists = modelMappings.value.some((m) => m.from === from)
if (exists) {
appStore.showInfo(t('admin.accounts.mappingExists', { model: from }))
return
}
modelMappings.value.push({ from, to })
}
// Error code helpers
const toggleErrorCode = (code: number) => {
const index = selectedErrorCodes.value.indexOf(code)
if (index === -1) {
selectedErrorCodes.value.push(code)
} else {
selectedErrorCodes.value.splice(index, 1)
}
}
const addCustomErrorCode = () => {
const code = customErrorCodeInput.value
if (code === null || code < 100 || code > 599) {
appStore.showError(t('admin.accounts.invalidErrorCode'))
return
}
if (selectedErrorCodes.value.includes(code)) {
appStore.showInfo(t('admin.accounts.errorCodeExists'))
return
}
selectedErrorCodes.value.push(code)
customErrorCodeInput.value = null
}
const removeErrorCode = (code: number) => {
const index = selectedErrorCodes.value.indexOf(code)
if (index !== -1) {
selectedErrorCodes.value.splice(index, 1)
}
}
const buildModelMappingObject = (): Record<string, string> | null => {
const mapping: Record<string, string> = {}
if (modelRestrictionMode.value === 'whitelist') {
for (const model of allowedModels.value) {
mapping[model] = model
}
} else {
for (const m of modelMappings.value) {
const from = m.from.trim()
const to = m.to.trim()
if (from && to) {
mapping[from] = to
}
}
}
return Object.keys(mapping).length > 0 ? mapping : null
}
const buildUpdatePayload = (): Record<string, unknown> | null => {
const updates: Record<string, unknown> = {}
const credentials: Record<string, unknown> = {}
let credentialsChanged = false
if (enableProxy.value) {
updates.proxy_id = proxyId.value
}
if (enableConcurrency.value) {
updates.concurrency = concurrency.value
}
if (enablePriority.value) {
updates.priority = priority.value
}
if (enableStatus.value) {
updates.status = status.value
}
if (enableGroups.value) {
updates.group_ids = groupIds.value
}
if (enableBaseUrl.value) {
const baseUrlValue = baseUrl.value.trim()
if (baseUrlValue) {
credentials.base_url = baseUrlValue
credentialsChanged = true
}
}
if (enableModelRestriction.value) {
const modelMapping = buildModelMappingObject()
if (modelMapping) {
credentials.model_mapping = modelMapping
credentialsChanged = true
}
}
if (enableCustomErrorCodes.value) {
credentials.custom_error_codes_enabled = true
credentials.custom_error_codes = [...selectedErrorCodes.value]
credentialsChanged = true
}
if (enableInterceptWarmup.value) {
credentials.intercept_warmup_requests = interceptWarmupRequests.value
credentialsChanged = true
}
if (credentialsChanged) {
updates.credentials = credentials
}
return Object.keys(updates).length > 0 ? updates : null
}
const handleClose = () => {
emit('close')
}
const handleSubmit = async () => {
if (props.accountIds.length === 0) {
appStore.showError(t('admin.accounts.bulkEdit.noSelection'))
return
}
const hasAnyFieldEnabled =
enableBaseUrl.value ||
enableModelRestriction.value ||
enableCustomErrorCodes.value ||
enableInterceptWarmup.value ||
enableProxy.value ||
enableConcurrency.value ||
enablePriority.value ||
enableStatus.value ||
enableGroups.value
if (!hasAnyFieldEnabled) {
appStore.showError(t('admin.accounts.bulkEdit.noFieldsSelected'))
return
}
const updates = buildUpdatePayload()
if (!updates) {
appStore.showError(t('admin.accounts.bulkEdit.noFieldsSelected'))
return
}
submitting.value = true
try {
const res = await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
const success = res.success || 0
const failed = res.failed || 0
if (success > 0 && failed === 0) {
appStore.showSuccess(t('admin.accounts.bulkEdit.success', { count: success }))
} else if (success > 0) {
appStore.showError(t('admin.accounts.bulkEdit.partialSuccess', { success, failed }))
} else {
appStore.showError(t('admin.accounts.bulkEdit.failed'))
}
if (success > 0) {
emit('updated')
handleClose()
}
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.accounts.bulkEdit.failed'))
console.error('Error bulk updating accounts:', error)
} finally {
submitting.value = false
}
}
// Reset form when modal closes
watch(
() => props.show,
(newShow) => {
if (!newShow) {
// Reset all enable flags
enableBaseUrl.value = false
enableModelRestriction.value = false
enableCustomErrorCodes.value = false
enableInterceptWarmup.value = false
enableProxy.value = false
enableConcurrency.value = false
enablePriority.value = false
enableStatus.value = false
enableGroups.value = false
// Reset all values
baseUrl.value = ''
modelRestrictionMode.value = 'whitelist'
allowedModels.value = []
modelMappings.value = []
selectedErrorCodes.value = []
customErrorCodeInput.value = null
interceptWarmupRequests.value = false
proxyId.value = null
concurrency.value = 1
priority.value = 1
status.value = 'active'
groupIds.value = []
}
}
)
</script>

View File

@@ -0,0 +1,170 @@
<template>
<Modal
:show="show"
:title="t('admin.accounts.syncFromCrsTitle')"
size="lg"
close-on-click-outside
@close="handleClose"
>
<div class="space-y-4">
<div class="text-sm text-gray-600 dark:text-dark-300">
{{ t('admin.accounts.syncFromCrsDesc') }}
</div>
<div class="text-xs text-gray-500 dark:text-dark-400 bg-gray-50 dark:bg-dark-700/60 rounded-lg p-3">
已有账号仅同步 CRS 返回的字段缺失字段保持原值凭据按键合并不会清空未下发的键未勾选"同步代理"时保留原有代理
</div>
<div class="text-xs text-amber-600 dark:text-amber-400 bg-amber-50 dark:bg-amber-900/20 border border-amber-200 dark:border-amber-800 rounded-lg p-3">
{{ t('admin.accounts.crsVersionRequirement') }}
</div>
<div class="grid grid-cols-1 gap-4">
<div>
<label class="input-label">{{ t('admin.accounts.crsBaseUrl') }}</label>
<input
v-model="form.base_url"
type="text"
class="input"
:placeholder="t('admin.accounts.crsBaseUrlPlaceholder')"
/>
</div>
<div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
<div>
<label class="input-label">{{ t('admin.accounts.crsUsername') }}</label>
<input
v-model="form.username"
type="text"
class="input"
autocomplete="username"
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.crsPassword') }}</label>
<input
v-model="form.password"
type="password"
class="input"
autocomplete="current-password"
/>
</div>
</div>
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-dark-300">
<input v-model="form.sync_proxies" type="checkbox" class="rounded border-gray-300 dark:border-dark-600" />
{{ t('admin.accounts.syncProxies') }}
</label>
</div>
<div v-if="result" class="rounded-xl border border-gray-200 dark:border-dark-700 p-4 space-y-2">
<div class="text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.syncResult') }}
</div>
<div class="text-sm text-gray-700 dark:text-dark-300">
{{ t('admin.accounts.syncResultSummary', result) }}
</div>
<div v-if="errorItems.length" class="mt-2">
<div class="text-sm font-medium text-red-600 dark:text-red-400">
{{ t('admin.accounts.syncErrors') }}
</div>
<div class="mt-2 max-h-48 overflow-auto rounded-lg bg-gray-50 dark:bg-dark-800 p-3 text-xs font-mono">
<div v-for="(item, idx) in errorItems" :key="idx" class="whitespace-pre-wrap">
{{ item.kind }} {{ item.crs_account_id }} {{ item.action }}{{ item.error ? `: ${item.error}` : '' }}
</div>
</div>
</div>
</div>
</div>
<template #footer>
<div class="flex justify-end gap-3">
<button class="btn btn-secondary" :disabled="syncing" @click="handleClose">
{{ t('common.cancel') }}
</button>
<button class="btn btn-primary" :disabled="syncing" @click="handleSync">
{{ syncing ? t('admin.accounts.syncing') : t('admin.accounts.syncNow') }}
</button>
</div>
</template>
</Modal>
</template>
<script setup lang="ts">
import { computed, reactive, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'
import Modal from '@/components/common/Modal.vue'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
interface Props {
show: boolean
}
interface Emits {
(e: 'close'): void
(e: 'synced'): void
}
const props = defineProps<Props>()
const emit = defineEmits<Emits>()
const { t } = useI18n()
const appStore = useAppStore()
const syncing = ref(false)
const result = ref<Awaited<ReturnType<typeof adminAPI.accounts.syncFromCrs>> | null>(null)
const form = reactive({
base_url: '',
username: '',
password: '',
sync_proxies: true
})
const errorItems = computed(() => {
if (!result.value?.items) return []
return result.value.items.filter((i) => i.action === 'failed' || i.action === 'skipped')
})
watch(
() => props.show,
(open) => {
if (open) {
result.value = null
}
}
)
const handleClose = () => {
emit('close')
}
const handleSync = async () => {
if (!form.base_url.trim() || !form.username.trim() || !form.password.trim()) {
appStore.showError(t('admin.accounts.syncMissingFields'))
return
}
syncing.value = true
try {
const res = await adminAPI.accounts.syncFromCrs({
base_url: form.base_url.trim(),
username: form.username.trim(),
password: form.password,
sync_proxies: form.sync_proxies
})
result.value = res
if (res.failed > 0) {
appStore.showError(t('admin.accounts.syncCompletedWithErrors', res))
} else {
appStore.showSuccess(t('admin.accounts.syncCompleted', res))
emit('synced')
}
} catch (error: any) {
appStore.showError(error?.message || t('admin.accounts.syncFailed'))
} finally {
syncing.value = false
}
}
</script>

View File

@@ -1,5 +1,6 @@
export { default as CreateAccountModal } from './CreateAccountModal.vue'
export { default as EditAccountModal } from './EditAccountModal.vue'
export { default as BulkEditAccountModal } from './BulkEditAccountModal.vue'
export { default as ReAuthAccountModal } from './ReAuthAccountModal.vue'
export { default as OAuthAuthorizationFlow } from './OAuthAuthorizationFlow.vue'
export { default as AccountStatusIndicator } from './AccountStatusIndicator.vue'
@@ -8,3 +9,4 @@ export { default as UsageProgressBar } from './UsageProgressBar.vue'
export { default as AccountStatsModal } from './AccountStatsModal.vue'
export { default as AccountTestModal } from './AccountTestModal.vue'
export { default as AccountTodayStatsCell } from './AccountTodayStatsCell.vue'
export { default as SyncFromCrsModal } from './SyncFromCrsModal.vue'

View File

@@ -17,11 +17,14 @@ export default {
},
features: {
unifiedGateway: 'Unified API Gateway',
unifiedGatewayDesc: 'Convert Claude subscriptions to API endpoints. Access AI capabilities through standard /v1/messages interface.',
unifiedGatewayDesc:
'Convert Claude subscriptions to API endpoints. Access AI capabilities through standard /v1/messages interface.',
multiAccount: 'Multi-Account Pool',
multiAccountDesc: 'Manage multiple upstream accounts with smart load balancing. Support OAuth and API Key authentication.',
multiAccountDesc:
'Manage multiple upstream accounts with smart load balancing. Support OAuth and API Key authentication.',
balanceQuota: 'Balance & Quota',
balanceQuotaDesc: 'Token-based billing with precise usage tracking. Manage quotas and recharge with redeem codes.',
balanceQuotaDesc:
'Token-based billing with precise usage tracking. Manage quotas and recharge with redeem codes.',
},
providers: {
title: 'Supported Providers',
@@ -235,7 +238,8 @@ export default {
useKey: 'Use Key',
useKeyModal: {
title: 'Use API Key',
description: 'Add the following environment variables to your terminal profile or run directly in terminal to configure API access.',
description:
'Add the following environment variables to your terminal profile or run directly in terminal to configure API access.',
copy: 'Copy',
copied: 'Copied',
note: 'These environment variables will be active in the current terminal session. For permanent configuration, add them to ~/.bashrc, ~/.zshrc, or the appropriate configuration file.',
@@ -517,7 +521,8 @@ export default {
failedToLoadApiKeys: 'Failed to load user API keys',
deleteConfirm: "Are you sure you want to delete '{email}'? This action cannot be undone.",
setAllowedGroups: 'Set Allowed Groups',
allowedGroupsHint: 'Select which standard groups this user can use. Subscription groups are managed separately.',
allowedGroupsHint:
'Select which standard groups this user can use. Subscription groups are managed separately.',
noStandardGroups: 'No standard groups available',
allowAllGroups: 'Allow All Groups',
allowAllGroupsHint: 'User can use any non-exclusive group',
@@ -529,8 +534,10 @@ export default {
depositAmount: 'Deposit Amount',
withdrawAmount: 'Withdraw Amount',
currentBalance: 'Current Balance',
depositNotesPlaceholder: 'e.g., New user registration bonus, promotional credit, compensation, etc.',
withdrawNotesPlaceholder: 'e.g., Service issue refund, incorrect charge reversal, account closure refund, etc.',
depositNotesPlaceholder:
'e.g., New user registration bonus, promotional credit, compensation, etc.',
withdrawNotesPlaceholder:
'e.g., Service issue refund, incorrect charge reversal, account closure refund, etc.',
notesOptional: 'Notes are optional but helpful for record keeping',
amountHint: 'Please enter a positive amount',
newBalance: 'New Balance',
@@ -597,12 +604,15 @@ export default {
failedToCreate: 'Failed to create group',
failedToUpdate: 'Failed to update group',
failedToDelete: 'Failed to delete group',
deleteConfirm: "Are you sure you want to delete '{name}'? All associated API keys will no longer belong to any group.",
deleteConfirmSubscription: "Are you sure you want to delete subscription group '{name}'? This will invalidate all API keys bound to this subscription and delete all related subscription records. This action cannot be undone.",
deleteConfirm:
"Are you sure you want to delete '{name}'? All associated API keys will no longer belong to any group.",
deleteConfirmSubscription:
"Are you sure you want to delete subscription group '{name}'? This will invalidate all API keys bound to this subscription and delete all related subscription records. This action cannot be undone.",
subscription: {
title: 'Subscription Settings',
type: 'Billing Type',
typeHint: 'Standard billing deducts from user balance. Subscription mode uses quota limits instead.',
typeHint:
'Standard billing deducts from user balance. Subscription mode uses quota limits instead.',
typeNotEditable: 'Billing type cannot be changed after group creation.',
standard: 'Standard (Balance)',
subscription: 'Subscription (Quota)',
@@ -674,7 +684,8 @@ export default {
failedToAssign: 'Failed to assign subscription',
failedToExtend: 'Failed to extend subscription',
failedToRevoke: 'Failed to revoke subscription',
revokeConfirm: "Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone.",
revokeConfirm:
"Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone.",
},
// Accounts
@@ -682,6 +693,26 @@ export default {
title: 'Account Management',
description: 'Manage AI platform accounts and credentials',
createAccount: 'Create Account',
syncFromCrs: 'Sync from CRS',
syncFromCrsTitle: 'Sync Accounts from CRS',
syncFromCrsDesc:
'Sync accounts from claude-relay-service (CRS) into this system (CRS is called server-to-server).',
crsVersionRequirement: '⚠️ Note: CRS version must be ≥ v1.1.240 to support this feature',
crsBaseUrl: 'CRS Base URL',
crsBaseUrlPlaceholder: 'e.g. http://127.0.0.1:3000',
crsUsername: 'Username',
crsPassword: 'Password',
syncProxies: 'Also sync proxies (match by host/port/auth or create)',
syncNow: 'Sync Now',
syncing: 'Syncing...',
syncMissingFields: 'Please fill base URL, username and password',
syncResult: 'Sync Result',
syncResultSummary: 'Created {created}, updated {updated}, skipped {skipped}, failed {failed}',
syncErrors: 'Errors / Skipped Details',
syncCompleted: 'Sync completed: created {created}, updated {updated}',
syncCompletedWithErrors:
'Sync completed with errors: failed {failed} (created {created}, updated {updated})',
syncFailed: 'Sync failed',
editAccount: 'Edit Account',
deleteAccount: 'Delete Account',
searchAccounts: 'Search accounts...',
@@ -726,6 +757,32 @@ export default {
tokenRefreshed: 'Token refreshed successfully',
accountDeleted: 'Account deleted successfully',
rateLimitCleared: 'Rate limit cleared successfully',
bulkActions: {
selected: '{count} account(s) selected',
selectCurrentPage: 'Select this page',
clear: 'Clear selection',
edit: 'Bulk Edit',
delete: 'Bulk Delete',
},
bulkEdit: {
title: 'Bulk Edit Accounts',
selectionInfo:
'{count} account(s) selected. Only checked or filled fields will be updated; others stay unchanged.',
baseUrlPlaceholder: 'https://api.anthropic.com or https://api.openai.com',
baseUrlNotice: 'Applies to API Key accounts only; leave empty to keep existing value',
submit: 'Update Accounts',
updating: 'Updating...',
success: 'Updated {count} account(s)',
partialSuccess: 'Partially updated: {success} succeeded, {failed} failed',
failed: 'Bulk update failed',
noSelection: 'Please select accounts to edit',
noFieldsSelected: 'Select at least one field to update',
},
bulkDeleteTitle: 'Bulk Delete Accounts',
bulkDeleteConfirm: 'Delete the selected {count} account(s)? This action cannot be undone.',
bulkDeleteSuccess: 'Deleted {count} account(s)',
bulkDeletePartial: 'Partially deleted: {success} succeeded, {failed} failed',
bulkDeleteFailed: 'Bulk delete failed',
resetStatus: 'Reset Status',
statusReset: 'Account status reset successfully',
failedToResetStatus: 'Failed to reset account status',
@@ -753,7 +810,8 @@ export default {
modelWhitelist: 'Model Whitelist',
modelMapping: 'Model Mapping',
selectAllowedModels: 'Select allowed models. Leave empty to support all models.',
mapRequestModels: 'Map request models to actual models. Left is the requested model, right is the actual model sent to API.',
mapRequestModels:
'Map request models to actual models. Left is the requested model, right is the actual model sent to API.',
selectedModels: 'Selected {count} model(s)',
supportsAllModels: '(supports all models)',
requestModel: 'Request model',
@@ -762,14 +820,16 @@ export default {
mappingExists: 'Mapping for {model} already exists',
customErrorCodes: 'Custom Error Codes',
customErrorCodesHint: 'Only stop scheduling for selected error codes',
customErrorCodesWarning: 'Only selected error codes will stop scheduling. Other errors will return 500.',
customErrorCodesWarning:
'Only selected error codes will stop scheduling. Other errors will return 500.',
selectedErrorCodes: 'Selected',
noneSelectedUsesDefault: 'None selected (uses default policy)',
enterErrorCode: 'Enter error code (100-599)',
invalidErrorCode: 'Please enter a valid HTTP error code (100-599)',
errorCodeExists: 'This error code is already selected',
interceptWarmupRequests: 'Intercept Warmup Requests',
interceptWarmupRequestsDesc: 'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
interceptWarmupRequestsDesc:
'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
proxy: 'Proxy',
noProxy: 'No Proxy',
concurrency: 'Concurrency',
@@ -792,11 +852,13 @@ export default {
authMethod: 'Authorization Method',
manualAuth: 'Manual Authorization',
cookieAutoAuth: 'Cookie Auto-Auth',
cookieAutoAuthDesc: 'Use claude.ai sessionKey to automatically complete OAuth authorization without manually opening browser.',
cookieAutoAuthDesc:
'Use claude.ai sessionKey to automatically complete OAuth authorization without manually opening browser.',
sessionKey: 'sessionKey',
keysCount: '{count} keys',
batchCreateAccounts: 'Will batch create {count} accounts',
sessionKeyPlaceholder: 'One sessionKey per line, e.g.:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
sessionKeyPlaceholder:
'One sessionKey per line, e.g.:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
sessionKeyPlaceholderSingle: 'sk-ant-sid01-xxxxx...',
howToGetSessionKey: 'How to get sessionKey',
step1: 'Login to <strong>claude.ai</strong> in your browser',
@@ -814,10 +876,13 @@ export default {
generating: 'Generating...',
regenerate: 'Regenerate',
step2OpenUrl: 'Open the URL in your browser and complete authorization',
openUrlDesc: 'Open the authorization URL in a new tab, log in to your Claude account and authorize.',
proxyWarning: '<strong>Note:</strong> If you configured a proxy, make sure your browser uses the same proxy to access the authorization page.',
openUrlDesc:
'Open the authorization URL in a new tab, log in to your Claude account and authorize.',
proxyWarning:
'<strong>Note:</strong> If you configured a proxy, make sure your browser uses the same proxy to access the authorization page.',
step3EnterCode: 'Enter the Authorization Code',
authCodeDesc: 'After authorization is complete, the page will display an <strong>Authorization Code</strong>. Copy and paste it below:',
authCodeDesc:
'After authorization is complete, the page will display an <strong>Authorization Code</strong>. Copy and paste it below:',
authCode: 'Authorization Code',
authCodePlaceholder: 'Paste the Authorization Code from Claude page...',
authCodeHint: 'Paste the Authorization Code copied from the Claude page',
@@ -835,13 +900,18 @@ export default {
step1GenerateUrl: 'Click the button below to generate the authorization URL',
generateAuthUrl: 'Generate Auth URL',
step2OpenUrl: 'Open the URL in your browser and complete authorization',
openUrlDesc: 'Open the authorization URL in a new tab, log in to your OpenAI account and authorize.',
importantNotice: '<strong>Important:</strong> The page may take a while to load after authorization. Please wait patiently. When the browser address bar changes to <code>http://localhost...</code>, the authorization is complete.',
openUrlDesc:
'Open the authorization URL in a new tab, log in to your OpenAI account and authorize.',
importantNotice:
'<strong>Important:</strong> The page may take a while to load after authorization. Please wait patiently. When the browser address bar changes to <code>http://localhost...</code>, the authorization is complete.',
step3EnterCode: 'Enter Authorization URL or Code',
authCodeDesc: 'After authorization is complete, when the page URL becomes <code>http://localhost:xxx/auth/callback?code=...</code>:',
authCodeDesc:
'After authorization is complete, when the page URL becomes <code>http://localhost:xxx/auth/callback?code=...</code>:',
authCode: 'Authorization URL or Code',
authCodePlaceholder: 'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
authCodeHint: 'You can copy the entire URL or just the code parameter value, the system will auto-detect',
authCodePlaceholder:
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
authCodeHint:
'You can copy the entire URL or just the code parameter value, the system will auto-detect',
},
},
// Re-Auth Modal
@@ -941,8 +1011,10 @@ export default {
standardAdd: 'Standard Add',
batchAdd: 'Quick Add',
batchInput: 'Proxy List',
batchInputPlaceholder: "Enter one proxy per line in the following formats:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
batchInputHint: "Supports http, https, socks5 protocols. Format: protocol://[user:pass{'@'}]host:port",
batchInputPlaceholder:
"Enter one proxy per line in the following formats:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
batchInputHint:
"Supports http, https, socks5 protocols. Format: protocol://[user:pass{'@'}]host:port",
parsedCount: '{count} valid',
invalidCount: '{count} invalid',
duplicateCount: '{count} duplicate',
@@ -965,7 +1037,8 @@ export default {
failedToUpdate: 'Failed to update proxy',
failedToDelete: 'Failed to delete proxy',
failedToTest: 'Failed to test proxy',
deleteConfirm: "Are you sure you want to delete '{name}'? Accounts using this proxy will have their proxy removed.",
deleteConfirm:
"Are you sure you want to delete '{name}'? Accounts using this proxy will have their proxy removed.",
},
// Redeem Codes
@@ -994,8 +1067,10 @@ export default {
exportCsv: 'Export CSV',
deleteAllUnused: 'Delete All Unused Codes',
deleteCode: 'Delete Redeem Code',
deleteCodeConfirm: 'Are you sure you want to delete this redeem code? This action cannot be undone.',
deleteAllUnusedConfirm: 'Are you sure you want to delete all unused (active) redeem codes? This action cannot be undone.',
deleteCodeConfirm:
'Are you sure you want to delete this redeem code? This action cannot be undone.',
deleteAllUnusedConfirm:
'Are you sure you want to delete all unused (active) redeem codes? This action cannot be undone.',
deleteAll: 'Delete All',
generateCodesTitle: 'Generate Redeem Codes',
generatedSuccessfully: 'Generated Successfully',
@@ -1075,7 +1150,8 @@ export default {
siteSubtitle: 'Site Subtitle',
siteSubtitleHint: 'Displayed on login and register pages',
apiBaseUrl: 'API Base URL',
apiBaseUrlHint: 'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
apiBaseUrlHint:
'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
contactInfo: 'Contact Info',
contactInfoPlaceholder: 'e.g., QQ: 123456789',
contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.',
@@ -1125,7 +1201,8 @@ export default {
create: 'Create Key',
creating: 'Creating...',
regenerateConfirm: 'Are you sure? The current key will be immediately invalidated.',
deleteConfirm: 'Are you sure you want to delete the admin API key? External integrations will stop working.',
deleteConfirm:
'Are you sure you want to delete the admin API key? External integrations will stop working.',
keyGenerated: 'New admin API key generated',
keyDeleted: 'Admin API key deleted',
copyKey: 'Copy Key',
@@ -1191,7 +1268,8 @@ export default {
title: 'My Subscriptions',
description: 'View your subscription plans and usage',
noActiveSubscriptions: 'No Active Subscriptions',
noActiveSubscriptionsDesc: 'You don\'t have any active subscriptions. Contact administrator to get one.',
noActiveSubscriptionsDesc:
"You don't have any active subscriptions. Contact administrator to get one.",
status: {
active: 'Active',
expired: 'Expired',

View File

@@ -620,7 +620,8 @@ export default {
editGroup: '编辑分组',
deleteGroup: '删除分组',
deleteConfirm: "确定要删除分组 '{name}' 吗?所有关联的 API 密钥将不再属于任何分组。",
deleteConfirmSubscription: "确定要删除订阅分组 '{name}' 吗?此操作会让所有绑定此订阅的用户的 API Key 失效,并删除所有相关的订阅记录。此操作无法撤销。",
deleteConfirmSubscription:
"确定要删除订阅分组 '{name}' 吗?此操作会让所有绑定此订阅的用户的 API Key 失效,并删除所有相关的订阅记录。此操作无法撤销。",
columns: {
name: '名称',
platform: '平台',
@@ -780,6 +781,25 @@ export default {
title: '账号管理',
description: '管理 AI 平台账号和 Cookie',
createAccount: '添加账号',
syncFromCrs: '从 CRS 同步',
syncFromCrsTitle: '从 CRS 同步账号',
syncFromCrsDesc:
'将 claude-relay-serviceCRS中的账号同步到当前系统不会在浏览器侧直接请求 CRS。',
crsVersionRequirement: '⚠️ 注意CRS 版本必须 ≥ v1.1.240 才支持此功能',
crsBaseUrl: 'CRS 服务地址',
crsBaseUrlPlaceholder: '例如http://127.0.0.1:3000',
crsUsername: '用户名',
crsPassword: '密码',
syncProxies: '同时同步代理(按 host/port/账号匹配或自动创建)',
syncNow: '开始同步',
syncing: '同步中...',
syncMissingFields: '请填写服务地址、用户名和密码',
syncResult: '同步结果',
syncResultSummary: '创建 {created},更新 {updated},跳过 {skipped},失败 {failed}',
syncErrors: '错误/跳过详情',
syncCompleted: '同步完成:创建 {created},更新 {updated}',
syncCompletedWithErrors: '同步完成但有错误:失败 {failed}(创建 {created},更新 {updated}',
syncFailed: '同步失败',
editAccount: '编辑账号',
deleteAccount: '删除账号',
deleteConfirmMessage: "确定要删除账号 '{name}' 吗?",
@@ -859,6 +879,31 @@ export default {
accountCreatedSuccess: '账号添加成功',
accountUpdatedSuccess: '账号更新成功',
accountDeletedSuccess: '账号删除成功',
bulkActions: {
selected: '已选择 {count} 个账号',
selectCurrentPage: '本页全选',
clear: '清除选择',
edit: '批量编辑账号',
delete: '批量删除',
},
bulkEdit: {
title: '批量编辑账号',
selectionInfo: '已选择 {count} 个账号。只更新您勾选或填写的字段,未勾选的字段保持不变。',
baseUrlPlaceholder: 'https://api.anthropic.com 或 https://api.openai.com',
baseUrlNotice: '仅适用于 API Key 账号,留空则不修改',
submit: '批量更新',
updating: '更新中...',
success: '成功更新 {count} 个账号',
partialSuccess: '部分更新成功:成功 {success} 个,失败 {failed} 个',
failed: '批量更新失败',
noSelection: '请选择要编辑的账号',
noFieldsSelected: '请至少选择一个要更新的字段',
},
bulkDeleteTitle: '批量删除账号',
bulkDeleteConfirm: '确定要删除选中的 {count} 个账号吗?此操作无法撤销。',
bulkDeleteSuccess: '成功删除 {count} 个账号',
bulkDeletePartial: '部分删除成功:成功 {success} 个,失败 {failed} 个',
bulkDeleteFailed: '批量删除失败',
resetStatus: '重置状态',
statusReset: '账号状态已重置',
failedToResetStatus: '重置账号状态失败',
@@ -931,7 +976,8 @@ export default {
sessionKey: 'sessionKey',
keysCount: '{count} 个密钥',
batchCreateAccounts: '将批量创建 {count} 个账号',
sessionKeyPlaceholder: '每行一个 sessionKey例如\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
sessionKeyPlaceholder:
'每行一个 sessionKey例如\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
sessionKeyPlaceholderSingle: 'sk-ant-sid01-xxxxx...',
howToGetSessionKey: '如何获取 sessionKey',
step1: '在浏览器中登录 <strong>claude.ai</strong>',
@@ -950,7 +996,8 @@ export default {
regenerate: '重新生成',
step2OpenUrl: '在浏览器中打开 URL 并完成授权',
openUrlDesc: '在新标签页中打开授权 URL登录您的 Claude 账号并授权。',
proxyWarning: '<strong>注意:</strong>如果您配置了代理,请确保浏览器使用相同的代理访问授权页面。',
proxyWarning:
'<strong>注意:</strong>如果您配置了代理,请确保浏览器使用相同的代理访问授权页面。',
step3EnterCode: '输入授权码',
authCodeDesc: '授权完成后,页面会显示一个 <strong>授权码</strong>。复制并粘贴到下方:',
authCode: '授权码',
@@ -971,11 +1018,14 @@ export default {
generateAuthUrl: '生成授权链接',
step2OpenUrl: '在浏览器中打开链接并完成授权',
openUrlDesc: '请在新标签页中打开授权链接,登录您的 OpenAI 账户并授权。',
importantNotice: '<strong>重要提示:</strong>授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 <code>http://localhost...</code> 开头时,表示授权已完成。',
importantNotice:
'<strong>重要提示:</strong>授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 <code>http://localhost...</code> 开头时,表示授权已完成。',
step3EnterCode: '输入授权链接或 Code',
authCodeDesc: '授权完成后,当页面地址变为 <code>http://localhost:xxx/auth/callback?code=...</code> 时:',
authCodeDesc:
'授权完成后,当页面地址变为 <code>http://localhost:xxx/auth/callback?code=...</code> 时:',
authCode: '授权链接或 Code',
authCodePlaceholder: '方式1复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2仅复制 code 参数的值',
authCodePlaceholder:
'方式1复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2仅复制 code 参数的值',
authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别',
},
},
@@ -1111,7 +1161,8 @@ export default {
standardAdd: '标准添加',
batchAdd: '快捷添加',
batchInput: '代理列表',
batchInputPlaceholder: "每行输入一个代理,支持以下格式:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
batchInputPlaceholder:
"每行输入一个代理,支持以下格式:\nsocks5://user:pass{'@'}192.168.1.1:1080\nhttp://192.168.1.1:8080\nhttps://user:pass{'@'}proxy.example.com:443",
batchInputHint: "支持 http、https、socks5 协议,格式:协议://[用户名:密码{'@'}]主机:端口",
parsedCount: '有效 {count} 个',
invalidCount: '无效 {count} 个',

View File

@@ -366,13 +366,24 @@ export interface AccountUsageInfo {
// OpenAI Codex usage snapshot (from response headers)
export interface CodexUsageSnapshot {
codex_primary_used_percent?: number; // Weekly limit usage percentage
codex_primary_reset_after_seconds?: number; // Seconds until weekly reset
codex_primary_window_minutes?: number; // Weekly window in minutes
codex_secondary_used_percent?: number; // 5h limit usage percentage
codex_secondary_reset_after_seconds?: number; // Seconds until 5h reset
codex_secondary_window_minutes?: number; // 5h window in minutes
// Legacy fields (kept for backwards compatibility)
// NOTE: The naming is ambiguous - actual window type is determined by window_minutes value
codex_primary_used_percent?: number; // Usage percentage (check window_minutes for actual window type)
codex_primary_reset_after_seconds?: number; // Seconds until reset
codex_primary_window_minutes?: number; // Window in minutes
codex_secondary_used_percent?: number; // Usage percentage (check window_minutes for actual window type)
codex_secondary_reset_after_seconds?: number; // Seconds until reset
codex_secondary_window_minutes?: number; // Window in minutes
codex_primary_over_secondary_percent?: number; // Overflow ratio
// Canonical fields (normalized by backend, use these preferentially)
codex_5h_used_percent?: number; // 5-hour window usage percentage
codex_5h_reset_after_seconds?: number; // Seconds until 5h window reset
codex_5h_window_minutes?: number; // 5h window in minutes (should be ~300)
codex_7d_used_percent?: number; // 7-day window usage percentage
codex_7d_reset_after_seconds?: number; // Seconds until 7d window reset
codex_7d_window_minutes?: number; // 7d window in minutes (should be ~10080)
codex_usage_updated_at?: string; // Last update timestamp
}

View File

@@ -16,6 +16,15 @@
<path stroke-linecap="round" stroke-linejoin="round" d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0l3.181 3.183a8.25 8.25 0 0013.803-3.7M4.031 9.865a8.25 8.25 0 0113.803-3.7l3.181 3.182m0-4.991v4.99" />
</svg>
</button>
<button
@click="showCrsSyncModal = true"
class="btn btn-secondary"
title="从 CRS 同步"
>
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
<path stroke-linecap="round" stroke-linejoin="round" d="M20.25 6.375c0 2.278-3.694 4.125-8.25 4.125S3.75 8.653 3.75 6.375m16.5 0c0-2.278-3.694-4.125-8.25-4.125S3.75 4.097 3.75 6.375m16.5 0v11.25c0 2.278-3.694 4.125-8.25 4.125s-8.25-1.847-8.25-4.125V6.375m16.5 0v3.75m-16.5-3.75v3.75m16.5 0v3.75C20.25 16.153 16.556 18 12 18s-8.25-1.847-8.25-4.125v-3.75m16.5 0c0 2.278-3.694 4.125-8.25 4.125s-8.25-1.847-8.25-4.125" />
</svg>
</button>
<button
@click="showCreateModal = true"
class="btn btn-primary"
@@ -66,9 +75,63 @@
</div>
</div>
<!-- Bulk Actions Bar -->
<div v-if="selectedAccountIds.length > 0" class="card bg-primary-50 dark:bg-primary-900/20 border-primary-200 dark:border-primary-800 px-4 py-3">
<div class="flex flex-wrap items-center justify-between gap-3">
<div class="flex flex-wrap items-center gap-2">
<span class="text-sm font-medium text-primary-900 dark:text-primary-100">
{{ t('admin.accounts.bulkActions.selected', { count: selectedAccountIds.length }) }}
</span>
<button
@click="selectCurrentPageAccounts"
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
>
{{ t('admin.accounts.bulkActions.selectCurrentPage') }}
</button>
<span class="text-gray-300 dark:text-primary-800"></span>
<button
@click="selectedAccountIds = []"
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
>
{{ t('admin.accounts.bulkActions.clear') }}
</button>
</div>
<div class="flex items-center gap-2">
<button
@click="handleBulkDelete"
class="btn btn-danger btn-sm"
>
<svg class="w-4 h-4 mr-1.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
<path stroke-linecap="round" stroke-linejoin="round"
d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0" />
</svg>
{{ t('admin.accounts.bulkActions.delete') }}
</button>
<button
@click="showBulkEditModal = true"
class="btn btn-primary btn-sm"
>
<svg class="w-4 h-4 mr-1.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
<path stroke-linecap="round" stroke-linejoin="round" d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L10.582 16.07a4.5 4.5 0 01-1.897 1.13L6 18l.8-2.685a4.5 4.5 0 011.13-1.897l8.932-8.931zm0 0L19.5 7.125M18 14v4.75A2.25 2.25 0 0115.75 21H5.25A2.25 2.25 0 013 18.75V8.25A2.25 2.25 0 015.25 6H10" />
</svg>
{{ t('admin.accounts.bulkActions.edit') }}
</button>
</div>
</div>
</div>
<!-- Accounts Table -->
<div class="card overflow-hidden">
<DataTable :columns="columns" :data="accounts" :loading="loading">
<template #cell-select="{ row }">
<input
type="checkbox"
:checked="selectedAccountIds.includes(row.id)"
@change="toggleAccountSelection(row.id)"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</template>
<template #cell-name="{ value }">
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
</template>
@@ -315,6 +378,32 @@
@confirm="confirmDelete"
@cancel="showDeleteDialog = false"
/>
<ConfirmDialog
:show="showBulkDeleteDialog"
:title="t('admin.accounts.bulkDeleteTitle')"
:message="t('admin.accounts.bulkDeleteConfirm', { count: selectedAccountIds.length })"
:confirm-text="t('common.delete')"
:cancel-text="t('common.cancel')"
:danger="true"
@confirm="confirmBulkDelete"
@cancel="showBulkDeleteDialog = false"
/>
<SyncFromCrsModal
:show="showCrsSyncModal"
@close="showCrsSyncModal = false"
@synced="handleCrsSynced"
/>
<!-- Bulk Edit Account Modal -->
<BulkEditAccountModal
:show="showBulkEditModal"
:account-ids="selectedAccountIds"
:proxies="proxies"
:groups="groups"
@close="showBulkEditModal = false"
@updated="handleBulkUpdated"
/>
</AppLayout>
</template>
@@ -331,7 +420,7 @@ import Pagination from '@/components/common/Pagination.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import EmptyState from '@/components/common/EmptyState.vue'
import Select from '@/components/common/Select.vue'
import { CreateAccountModal, EditAccountModal, ReAuthAccountModal, AccountStatsModal } from '@/components/account'
import { CreateAccountModal, EditAccountModal, BulkEditAccountModal, ReAuthAccountModal, AccountStatsModal, SyncFromCrsModal } from '@/components/account'
import AccountStatusIndicator from '@/components/account/AccountStatusIndicator.vue'
import AccountUsageCell from '@/components/account/AccountUsageCell.vue'
import AccountTodayStatsCell from '@/components/account/AccountTodayStatsCell.vue'
@@ -345,6 +434,7 @@ const appStore = useAppStore()
// Table columns
const columns = computed<Column[]>(() => [
{ key: 'select', label: '', sortable: false },
{ key: 'name', label: t('admin.accounts.columns.name'), sortable: true },
{ key: 'platform_type', label: t('admin.accounts.columns.platformType'), sortable: false },
{ key: 'concurrency', label: t('admin.accounts.columns.concurrencyStatus'), sortable: false },
@@ -402,14 +492,26 @@ const showCreateModal = ref(false)
const showEditModal = ref(false)
const showReAuthModal = ref(false)
const showDeleteDialog = ref(false)
const showBulkDeleteDialog = ref(false)
const showTestModal = ref(false)
const showStatsModal = ref(false)
const showCrsSyncModal = ref(false)
const showBulkEditModal = ref(false)
const editingAccount = ref<Account | null>(null)
const reAuthAccount = ref<Account | null>(null)
const deletingAccount = ref<Account | null>(null)
const testingAccount = ref<Account | null>(null)
const statsAccount = ref<Account | null>(null)
const togglingSchedulable = ref<number | null>(null)
const bulkDeleting = ref(false)
// Bulk selection
const selectedAccountIds = ref<number[]>([])
const selectCurrentPageAccounts = () => {
const pageIds = accounts.value.map(account => account.id)
const merged = new Set([...selectedAccountIds.value, ...pageIds])
selectedAccountIds.value = Array.from(merged)
}
// Rate limit / Overload helpers
const isRateLimited = (account: Account): boolean => {
@@ -480,6 +582,11 @@ const handlePageChange = (page: number) => {
loadAccounts()
}
const handleCrsSynced = () => {
showCrsSyncModal.value = false
loadAccounts()
}
// Edit modal
const handleEdit = (account: Account) => {
editingAccount.value = account
@@ -535,6 +642,38 @@ const confirmDelete = async () => {
}
}
const handleBulkDelete = () => {
if (selectedAccountIds.value.length === 0) return
showBulkDeleteDialog.value = true
}
const confirmBulkDelete = async () => {
if (bulkDeleting.value || selectedAccountIds.value.length === 0) return
bulkDeleting.value = true
const ids = [...selectedAccountIds.value]
try {
const results = await Promise.allSettled(ids.map(id => adminAPI.accounts.delete(id)))
const success = results.filter(result => result.status === 'fulfilled').length
const failed = results.length - success
if (failed === 0) {
appStore.showSuccess(t('admin.accounts.bulkDeleteSuccess', { count: success }))
} else {
appStore.showError(t('admin.accounts.bulkDeletePartial', { success, failed }))
}
showBulkDeleteDialog.value = false
selectedAccountIds.value = []
loadAccounts()
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.accounts.bulkDeleteFailed'))
console.error('Error deleting accounts:', error)
} finally {
bulkDeleting.value = false
}
}
// Clear rate limit
const handleClearRateLimit = async (account: Account) => {
try {
@@ -608,6 +747,23 @@ const closeStatsModal = () => {
statsAccount.value = null
}
// Bulk selection toggle
const toggleAccountSelection = (accountId: number) => {
const index = selectedAccountIds.value.indexOf(accountId)
if (index === -1) {
selectedAccountIds.value.push(accountId)
} else {
selectedAccountIds.value.splice(index, 1)
}
}
// Bulk update handler
const handleBulkUpdated = () => {
showBulkEditModal.value = false
selectedAccountIds.value = []
loadAccounts()
}
// Initialize
onMounted(() => {
loadAccounts()