mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-07 00:40:22 +08:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d30ceae8d | ||
|
|
60f6ed6bf6 | ||
|
|
4a2f7d4a99 | ||
|
|
c19a393be9 | ||
|
|
938ffb002e | ||
|
|
372a01290b | ||
|
|
8b163ca49b | ||
|
|
d23810dc53 | ||
|
|
62ed5422dd | ||
|
|
2e76302af7 | ||
|
|
6553828008 | ||
|
|
adcb7bf00e | ||
|
|
876e85e7ad | ||
|
|
2e7818d688 | ||
|
|
836c4dda2b | ||
|
|
e65e9587b4 | ||
|
|
aaadd6ed04 | ||
|
|
870b21916c | ||
|
|
fb119f9a67 | ||
|
|
ad54795a24 | ||
|
|
0abe322cca | ||
|
|
b071511676 | ||
|
|
7d9a757a26 | ||
|
|
bbf4024dc7 | ||
|
|
5831eb8a6a | ||
|
|
61838cdb3d | ||
|
|
50dba656fd | ||
|
|
0e2821456c | ||
|
|
f25ac3aff5 | ||
|
|
f6341b7f2b | ||
|
|
4e257512b9 | ||
|
|
e53b34f321 | ||
|
|
12ddae0184 | ||
|
|
7b9c3f165e | ||
|
|
0b8e84f942 | ||
|
|
d9e27df9af | ||
|
|
f0fabf89a1 | ||
|
|
5bbfbcdae9 | ||
|
|
eb55947ec4 | ||
|
|
5f7e5184eb | ||
|
|
008a111268 | ||
|
|
fda753278c | ||
|
|
6c469b42ed |
94
.github/workflows/release.yml
vendored
94
.github/workflows/release.yml
vendored
@@ -85,6 +85,19 @@ jobs:
|
||||
go-version: '1.24'
|
||||
cache-dependency-path: backend/go.sum
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Fetch tags with annotations
|
||||
run: |
|
||||
# 确保获取完整的 annotated tag 信息
|
||||
@@ -117,87 +130,16 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
|
||||
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
|
||||
# ===========================================================================
|
||||
# Docker Build and Push
|
||||
# ===========================================================================
|
||||
docker:
|
||||
needs: [update-version, build-frontend]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Download VERSION artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: version-file
|
||||
path: backend/cmd/server/
|
||||
|
||||
- name: Download frontend artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: backend/internal/web/dist/
|
||||
|
||||
# Extract version from tag
|
||||
- name: Extract version
|
||||
id: version
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Version: $VERSION"
|
||||
|
||||
# Set up Docker Buildx for multi-platform builds
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# Login to DockerHub
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Extract metadata for Docker
|
||||
- name: Extract Docker metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
weishaw/sub2api
|
||||
tags: |
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
# Build and push Docker image
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.version }}
|
||||
COMMIT=${{ github.sha }}
|
||||
DATE=${{ github.event.head_commit.timestamp }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# Update DockerHub description (optional)
|
||||
# Update DockerHub description
|
||||
- name: Update DockerHub description
|
||||
uses: peter-evans/dockerhub-description@v4
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
repository: weishaw/sub2api
|
||||
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
|
||||
short-description: "Sub2API - AI API Gateway Platform"
|
||||
readme-filepath: ./deploy/DOCKER.md
|
||||
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -28,6 +28,7 @@ node_modules/
|
||||
frontend/node_modules/
|
||||
frontend/dist/
|
||||
*.local
|
||||
*.tsbuildinfo
|
||||
|
||||
# 日志
|
||||
npm-debug.log*
|
||||
@@ -91,6 +92,13 @@ backend/internal/web/dist/*
|
||||
# 后端运行时缓存数据
|
||||
backend/data/
|
||||
|
||||
# ===================
|
||||
# 本地配置文件(包含敏感信息)
|
||||
# ===================
|
||||
backend/config.yaml
|
||||
deploy/config.yaml
|
||||
backend/.installed
|
||||
|
||||
# ===================
|
||||
# 其他
|
||||
# ===================
|
||||
|
||||
@@ -52,10 +52,58 @@ changelog:
|
||||
# 禁用自动 changelog,完全使用 tag 消息
|
||||
disable: true
|
||||
|
||||
# Docker images
|
||||
dockers:
|
||||
- id: amd64
|
||||
goos: linux
|
||||
goarch: amd64
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||
|
||||
- id: arm64
|
||||
goos: linux
|
||||
goarch: arm64
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||
|
||||
# Docker manifests for multi-arch support
|
||||
docker_manifests:
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: Wei-Shaw
|
||||
name: sub2api
|
||||
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
|
||||
name: "{{ .Env.GITHUB_REPO_NAME }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
name_template: "Sub2API {{.Version}}"
|
||||
@@ -73,7 +121,7 @@ release:
|
||||
|
||||
**One-line install (Linux):**
|
||||
```bash
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
|
||||
curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
|
||||
```
|
||||
|
||||
**Manual download:**
|
||||
@@ -81,5 +129,5 @@ release:
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
|
||||
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
|
||||
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
|
||||
- [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)
|
||||
|
||||
40
Dockerfile.goreleaser
Normal file
40
Dockerfile.goreleaser
Normal file
@@ -0,0 +1,40 @@
|
||||
# =============================================================================
|
||||
# Sub2API Dockerfile for GoReleaser
|
||||
# =============================================================================
|
||||
# This Dockerfile is used by GoReleaser to build Docker images.
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy pre-built binary from GoReleaser
|
||||
COPY sub2api /app/sub2api
|
||||
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||
|
||||
USER sub2api
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/sub2api"]
|
||||
@@ -17,10 +17,17 @@ linters:
|
||||
service-no-repository:
|
||||
list-mode: original
|
||||
files:
|
||||
- internal/service/**
|
||||
- "**/internal/service/**"
|
||||
deny:
|
||||
- pkg: sub2api/internal/repository
|
||||
desc: "service must not import repository"
|
||||
handler-no-repository:
|
||||
list-mode: original
|
||||
files:
|
||||
- "**/internal/handler/**"
|
||||
deny:
|
||||
- pkg: sub2api/internal/repository
|
||||
desc: "handler must not import repository"
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/setup"
|
||||
"sub2api/internal/web"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -4,12 +4,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/infrastructure"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/server"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"context"
|
||||
"log"
|
||||
@@ -85,6 +85,14 @@ func provideCleanup(
|
||||
services.EmailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
services.OAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
services.OpenAIOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
|
||||
@@ -8,17 +8,17 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"net/http"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/handler/admin"
|
||||
"sub2api/internal/infrastructure"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/server"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
||||
billingCache := repository.NewBillingCache(client)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
|
||||
@@ -67,7 +67,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
dashboardHandler := admin.NewDashboardHandler(usageLogRepository)
|
||||
dashboardService := service.NewDashboardService(usageLogRepository)
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService)
|
||||
accountRepository := repository.NewAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||
@@ -76,13 +77,19 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
|
||||
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||
@@ -92,8 +99,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
|
||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
gatewayCache := repository.NewGatewayCache(client)
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
@@ -103,43 +110,45 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityCache := repository.NewIdentityCache(client)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||
groupService := service.NewGroupService(groupRepository)
|
||||
accountService := service.NewAccountService(accountRepository, groupRepository)
|
||||
proxyService := service.NewProxyService(proxyRepository)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
|
||||
services := &service.Services{
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OAuth: oAuthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OpenAIGateway: openAIGatewayService,
|
||||
OAuth: oAuthService,
|
||||
OpenAIOAuth: openAIOAuthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
}
|
||||
repositories := &repository.Repositories{
|
||||
User: userRepository,
|
||||
@@ -201,6 +210,14 @@ func provideCleanup(
|
||||
services.EmailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
services.OAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
services.OpenAIOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
mode: "debug" # debug/release
|
||||
|
||||
database:
|
||||
host: "127.0.0.1"
|
||||
port: 5432
|
||||
user: "postgres"
|
||||
password: "XZeRr7nkjHWhm8fw"
|
||||
dbname: "sub2api"
|
||||
sslmode: "disable"
|
||||
|
||||
redis:
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
jwt:
|
||||
secret: "your-secret-key-change-in-production"
|
||||
expire_hour: 24
|
||||
|
||||
default:
|
||||
admin_email: "admin@sub2api.com"
|
||||
admin_password: "admin123"
|
||||
user_concurrency: 5
|
||||
user_balance: 0
|
||||
api_key_prefix: "sk-"
|
||||
rate_multiplier: 1.0
|
||||
|
||||
# Timezone configuration (similar to PHP's date_default_timezone_set)
|
||||
# This affects ALL time operations:
|
||||
# - Database timestamps
|
||||
# - Usage statistics "today" boundary
|
||||
# - Subscription expiry times
|
||||
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
|
||||
timezone: "Asia/Shanghai"
|
||||
@@ -1,4 +1,4 @@
|
||||
module sub2api
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.24.0
|
||||
|
||||
@@ -8,10 +8,13 @@ require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/imroc/req/v3 v3.56.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/redis/go-redis/v9 v9.3.0
|
||||
github.com/spf13/viper v1.18.2
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
golang.org/x/crypto v0.44.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/term v0.37.0
|
||||
@@ -35,7 +38,6 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/google/wire v0.7.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/icholy/digest v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
@@ -64,6 +66,8 @@ require (
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
|
||||
@@ -139,6 +139,15 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
|
||||
@@ -3,9 +3,12 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -26,19 +29,34 @@ func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
|
||||
type AccountHandler struct {
|
||||
adminService service.AdminService
|
||||
oauthService *service.OAuthService
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
rateLimitService *service.RateLimitService
|
||||
accountUsageService *service.AccountUsageService
|
||||
accountTestService *service.AccountTestService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
|
||||
func NewAccountHandler(
|
||||
adminService service.AdminService,
|
||||
oauthService *service.OAuthService,
|
||||
openaiOAuthService *service.OpenAIOAuthService,
|
||||
rateLimitService *service.RateLimitService,
|
||||
accountUsageService *service.AccountUsageService,
|
||||
accountTestService *service.AccountTestService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
crsSyncService *service.CRSSyncService,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
rateLimitService: rateLimitService,
|
||||
accountUsageService: accountUsageService,
|
||||
accountTestService: accountTestService,
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +87,25 @@ type UpdateAccountRequest struct {
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
|
||||
type BulkUpdateAccountsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
type AccountWithConcurrency struct {
|
||||
*model.Account
|
||||
CurrentConcurrency int `json:"current_concurrency"`
|
||||
}
|
||||
|
||||
// List handles listing all accounts with pagination
|
||||
// GET /api/v1/admin/accounts
|
||||
func (h *AccountHandler) List(c *gin.Context) {
|
||||
@@ -84,7 +121,28 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, accounts, total, page, pageSize)
|
||||
// Get current concurrency counts for all accounts
|
||||
accountIDs := make([]int64, len(accounts))
|
||||
for i, acc := range accounts {
|
||||
accountIDs[i] = acc.ID
|
||||
}
|
||||
|
||||
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
|
||||
if err != nil {
|
||||
// Log error but don't fail the request, just use 0 for all
|
||||
concurrencyCounts = make(map[int64]int)
|
||||
}
|
||||
|
||||
// Build response with concurrency info
|
||||
result := make([]AccountWithConcurrency, len(accounts))
|
||||
for i := range accounts {
|
||||
result[i] = AccountWithConcurrency{
|
||||
Account: &accounts[i],
|
||||
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
|
||||
}
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting an account by ID
|
||||
@@ -190,6 +248,13 @@ type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
SyncProxies *bool `json:"sync_proxies"`
|
||||
}
|
||||
|
||||
// Test handles testing account connectivity with SSE streaming
|
||||
// POST /api/v1/admin/accounts/:id/test
|
||||
func (h *AccountHandler) Test(c *gin.Context) {
|
||||
@@ -210,6 +275,35 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
// POST /api/v1/admin/accounts/sync/crs
|
||||
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
||||
var req SyncFromCRSRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Default to syncing proxies (can be disabled by explicitly setting false)
|
||||
syncProxies := true
|
||||
if req.SyncProxies != nil {
|
||||
syncProxies = *req.SyncProxies
|
||||
}
|
||||
|
||||
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
||||
BaseURL: req.BaseURL,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
SyncProxies: syncProxies,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Sync failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
@@ -232,26 +326,47 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Use OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
var newCredentials map[string]any
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update token-related fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
newCredentials = make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
// Update token-related fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
@@ -273,15 +388,26 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
_ = accountID
|
||||
response.Success(c, gin.H{
|
||||
"total_requests": 0,
|
||||
"successful_requests": 0,
|
||||
"failed_requests": 0,
|
||||
"total_tokens": 0,
|
||||
"average_response_time": 0,
|
||||
})
|
||||
// Parse days parameter (default 30)
|
||||
days := 30
|
||||
if daysStr := c.Query("days"); daysStr != "" {
|
||||
if d, err := strconv.Atoi(daysStr); err == nil && d > 0 && d <= 90 {
|
||||
days = d
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate time range
|
||||
now := timezone.Now()
|
||||
endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1))
|
||||
|
||||
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get account stats: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// ClearError handles clearing account error
|
||||
@@ -321,6 +447,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// BatchUpdateCredentialsRequest represents batch credentials update request
|
||||
type BatchUpdateCredentialsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
// BatchUpdateCredentials handles batch updating credentials fields
|
||||
// POST /api/v1/admin/accounts/batch-update-credentials
|
||||
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
||||
var req BatchUpdateCredentialsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Validate value type based on field
|
||||
if req.Field == "intercept_warmup_requests" {
|
||||
// Must be boolean
|
||||
if _, ok := req.Value.(bool); !ok {
|
||||
response.BadRequest(c, "intercept_warmup_requests must be boolean")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// account_uuid and org_uuid can be string or null
|
||||
if req.Value != nil {
|
||||
if _, ok := req.Value.(string); !ok {
|
||||
response.BadRequest(c, req.Field+" must be string or null")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := []gin.H{}
|
||||
|
||||
for _, accountID := range req.AccountIDs {
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": "Account not found",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Update credentials field
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
|
||||
account.Credentials[req.Field] = req.Value
|
||||
|
||||
// Update account
|
||||
updateInput := &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
}
|
||||
|
||||
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
})
|
||||
}
|
||||
|
||||
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
|
||||
// POST /api/v1/admin/accounts/bulk-update
|
||||
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
var req BulkUpdateAccountsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hasUpdates := req.Name != "" ||
|
||||
req.ProxyID != nil ||
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.Status != "" ||
|
||||
req.GroupIDs != nil ||
|
||||
len(req.Credentials) > 0 ||
|
||||
len(req.Extra) > 0
|
||||
|
||||
if !hasUpdates {
|
||||
response.BadRequest(c, "No updates provided")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
|
||||
AccountIDs: req.AccountIDs,
|
||||
Name: req.Name,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ========== OAuth Handlers ==========
|
||||
|
||||
// GenerateAuthURLRequest represents the request for generating auth URL
|
||||
@@ -563,6 +819,46 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// Return mapped models
|
||||
var models []openai.Model
|
||||
for requestedModel := range mapping {
|
||||
var found bool
|
||||
for _, dm := range openai.DefaultModels {
|
||||
if dm.ID == requestedModel {
|
||||
models = append(models, dm)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
models = append(models, openai.Model{
|
||||
ID: requestedModel,
|
||||
Object: "model",
|
||||
Type: "model",
|
||||
DisplayName: requestedModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
response.Success(c, models)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
response.Success(c, claude.DefaultModels)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"strconv"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -12,15 +12,15 @@ import (
|
||||
|
||||
// DashboardHandler handles admin dashboard statistics
|
||||
type DashboardHandler struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
dashboardService *service.DashboardService
|
||||
startTime time.Time // Server start time for uptime calculation
|
||||
}
|
||||
|
||||
// NewDashboardHandler creates a new admin dashboard handler
|
||||
func NewDashboardHandler(usageRepo *repository.UsageLogRepository) *DashboardHandler {
|
||||
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
|
||||
return &DashboardHandler{
|
||||
usageRepo: usageRepo,
|
||||
startTime: time.Now(),
|
||||
dashboardService: dashboardService,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
// GetStats handles getting dashboard statistics
|
||||
// GET /api/v1/admin/dashboard/stats
|
||||
func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context())
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
@@ -107,6 +107,10 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
// 系统运行统计
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
"uptime": uptime,
|
||||
|
||||
// 性能指标
|
||||
"rpm": stats.Rpm,
|
||||
"tpm": stats.Tpm,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -142,7 +146,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -175,7 +179,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
@@ -200,7 +204,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
@@ -226,7 +230,7 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
@@ -259,7 +263,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
@@ -287,7 +291,7 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
@@ -3,9 +3,9 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
228
backend/internal/handler/admin/openai_oauth_handler.go
Normal file
228
backend/internal/handler/admin/openai_oauth_handler.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
|
||||
type OpenAIOAuthHandler struct {
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
|
||||
type OpenAIGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates OpenAI OAuth authorization URL
|
||||
// POST /api/v1/admin/openai/generate-auth-url
|
||||
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req OpenAIGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body
|
||||
req = OpenAIGenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges OpenAI authorization code for tokens
|
||||
// POST /api/v1/admin/openai/exchange-code
|
||||
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req OpenAIExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to refresh token: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
}
|
||||
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedAccount)
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create account: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -60,6 +60,7 @@ type UpdateSettingsRequest struct {
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -104,6 +105,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
ApiBaseUrl: req.ApiBaseUrl,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocUrl: req.DocUrl,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/sysutil"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -4,35 +4,32 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
usageService *service.UsageService
|
||||
adminService service.AdminService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
func NewUsageHandler(
|
||||
usageRepo *repository.UsageLogRepository,
|
||||
apiKeyRepo *repository.ApiKeyRepository,
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
adminService service.AdminService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageRepo: usageRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
usageService: usageService,
|
||||
adminService: adminService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,14 +81,14 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := repository.UsageLogFilters{
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
||||
return
|
||||
@@ -179,7 +176,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get global stats
|
||||
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
return
|
||||
@@ -237,7 +234,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
userID = id
|
||||
}
|
||||
|
||||
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to search API keys: "+err.Error())
|
||||
return
|
||||
|
||||
@@ -3,8 +3,8 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -25,6 +25,9 @@ func NewUserHandler(adminService service.AdminService) *UserHandler {
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Wechat string `json:"wechat"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
@@ -35,6 +38,9 @@ type CreateUserRequest struct {
|
||||
type UpdateUserRequest struct {
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Password string `json:"password" binding:"omitempty,min=6"`
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
Notes *string `json:"notes"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
@@ -43,8 +49,9 @@ type UpdateUserRequest struct {
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
type UpdateBalanceRequest struct {
|
||||
Balance float64 `json:"balance" binding:"required"`
|
||||
Balance float64 `json:"balance" binding:"required,gt=0"`
|
||||
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all users with pagination
|
||||
@@ -94,6 +101,9 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
@@ -125,6 +135,9 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
@@ -171,7 +184,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update balance: "+err.Error())
|
||||
return
|
||||
|
||||
@@ -3,10 +3,10 @@ package handler
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -10,27 +10,21 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum wait time for concurrency slot
|
||||
maxConcurrencyWait = 60 * time.Second
|
||||
// Ping interval during wait
|
||||
pingInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
userService *service.UserService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
@@ -38,8 +32,8 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
userService: userService,
|
||||
concurrencyService: concurrencyService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +83,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
@@ -98,10 +92,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 确保在函数退出时减少wait计数
|
||||
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
@@ -139,7 +133,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
@@ -173,135 +167,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}()
|
||||
}
|
||||
|
||||
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
|
||||
// For streaming requests, sends ping events during the wait
|
||||
// streamStarted is updated if streaming response has begun
|
||||
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
|
||||
// For streaming requests, sends ping events during the wait
|
||||
// streamStarted is updated if streaming response has begun
|
||||
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// concurrencyError represents a concurrency limit error with context
|
||||
type concurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *concurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
|
||||
// Note: For streaming requests, we send ping to keep the connection alive.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
|
||||
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
defer cancel()
|
||||
|
||||
// For streaming requests, set up SSE headers for ping
|
||||
var flusher http.Flusher
|
||||
if isStream {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &concurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingTicker.C:
|
||||
// Send ping for streaming requests to keep connection alive
|
||||
if isStream && flusher != nil {
|
||||
// Set headers on first ping (lazy initialization)
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles listing available models
|
||||
// GET /v1/models
|
||||
// Returns different model lists based on the API key's group platform
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
apiKey, _ := middleware.GetApiKeyFromContext(c)
|
||||
|
||||
// Return OpenAI models for OpenAI platform groups
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": openai.DefaultModels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Default: Claude models
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": claude.DefaultModels,
|
||||
"object": "list",
|
||||
"data": claude.DefaultModels,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
180
backend/internal/handler/gateway_helper.go
Normal file
180
backend/internal/handler/gateway_helper.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
|
||||
maxConcurrencyWait = 30 * time.Second
|
||||
// pingInterval is the interval for sending ping events during slot wait
|
||||
pingInterval = 15 * time.Second
|
||||
)
|
||||
|
||||
// SSEPingFormat defines the format of SSE ping events for different platforms
|
||||
type SSEPingFormat string
|
||||
|
||||
const (
|
||||
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
|
||||
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
|
||||
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
|
||||
SSEPingFormatNone SSEPingFormat = ""
|
||||
)
|
||||
|
||||
// ConcurrencyError represents a concurrency limit error with context
|
||||
type ConcurrencyError struct {
|
||||
SlotType string
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
func (e *ConcurrencyError) Error() string {
|
||||
if e.IsTimeout {
|
||||
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
|
||||
}
|
||||
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
|
||||
}
|
||||
|
||||
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
|
||||
type ConcurrencyHelper struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
pingFormat SSEPingFormat
|
||||
}
|
||||
|
||||
// NewConcurrencyHelper creates a new ConcurrencyHelper
|
||||
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
|
||||
return &ConcurrencyHelper{
|
||||
concurrencyService: concurrencyService,
|
||||
pingFormat: pingFormat,
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementWaitCount increments the wait count for a user
|
||||
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait count for a user
|
||||
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
defer cancel()
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
var flusher http.Flusher
|
||||
if needPing {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
}
|
||||
|
||||
// Only create ping ticker if ping is needed
|
||||
var pingCh <-chan time.Time
|
||||
if needPing {
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &ConcurrencyError{
|
||||
SlotType: slotType,
|
||||
IsTimeout: true,
|
||||
}
|
||||
|
||||
case <-pingCh:
|
||||
// Send ping to keep connection alive
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-pollTicker.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
)
|
||||
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
@@ -11,6 +11,7 @@ type AdminHandlers struct {
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Setting *admin.SettingHandler
|
||||
@@ -21,15 +22,16 @@ type AdminHandlers struct {
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
type Handlers struct {
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
Setting *SettingHandler
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
209
backend/internal/handler/openai_gateway_handler.go
Normal file
209
backend/internal/handler/openai_gateway_handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
) *OpenAIGatewayHandler {
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
|
||||
}
|
||||
}
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body to map for potential modification
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// For non-Codex CLI requests, set default instructions
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
reqBody["instructions"] = openai.DefaultInstructions
|
||||
// Re-serialize body
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// Get subscription info (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. Check if wait queue is full
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
// Ensure wait count is decremented when function exits
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
|
||||
// 1. First acquire user concurrency slot
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if err != nil {
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: user,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Normal case: return JSON response with proper status code
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -4,12 +4,11 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -17,15 +16,13 @@ import (
|
||||
// UsageHandler handles usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new UsageHandler
|
||||
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
usageRepo: usageRepo,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
@@ -260,7 +257,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
|
||||
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get dashboard statistics")
|
||||
return
|
||||
@@ -287,7 +284,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
||||
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage trend")
|
||||
return
|
||||
@@ -318,7 +315,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
|
||||
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
||||
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get model statistics")
|
||||
return
|
||||
@@ -387,7 +384,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -26,6 +26,12 @@ type ChangePasswordRequest struct {
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest represents the update profile request payload
|
||||
type UpdateProfileRequest struct {
|
||||
Username *string `json:"username"`
|
||||
Wechat *string `json:"wechat"`
|
||||
}
|
||||
|
||||
// GetProfile handles getting user profile
|
||||
// GET /api/v1/users/me
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
@@ -47,6 +53,9 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, userData)
|
||||
}
|
||||
|
||||
@@ -83,3 +92,40 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
|
||||
response.Success(c, gin.H{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// UpdateProfile handles updating user profile
|
||||
// PUT /api/v1/users/me
|
||||
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateProfileRequest{
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
}
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to update profile: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, updatedUser)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/handler/admin"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
@@ -27,6 +28,7 @@ func ProvideAdminHandlers(
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Setting: settingHandler,
|
||||
@@ -56,18 +58,20 @@ func ProvideHandlers(
|
||||
subscriptionHandler *SubscriptionHandler,
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +85,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewRedeemHandler,
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
@@ -89,6 +94,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewSettingHandler,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
@@ -3,9 +3,9 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -3,9 +3,9 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"log"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -2,9 +2,9 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -68,7 +68,8 @@ type Account struct {
|
||||
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
|
||||
|
||||
// 虚拟字段 (不存储到数据库)
|
||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||
Groups []*Group `gorm:"-" json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func (Account) TableName() string {
|
||||
@@ -277,3 +278,138 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// =============== OpenAI 相关方法 ===============
|
||||
|
||||
// IsOpenAI 检查是否为 OpenAI 平台账号
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
return a.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
// IsAnthropic 检查是否为 Anthropic 平台账号
|
||||
func (a *Account) IsAnthropic() bool {
|
||||
return a.Platform == PlatformAnthropic
|
||||
}
|
||||
|
||||
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
|
||||
func (a *Account) IsOpenAIOAuth() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号(Response 账号)
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
}
|
||||
|
||||
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
|
||||
// 对于 API Key 类型账号,从 credentials 中获取 base_url
|
||||
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeApiKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
}
|
||||
}
|
||||
return "https://api.openai.com" // OpenAI 默认 API URL
|
||||
}
|
||||
|
||||
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
|
||||
func (a *Account) GetOpenAIAccessToken() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("access_token")
|
||||
}
|
||||
|
||||
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
|
||||
func (a *Account) GetOpenAIRefreshToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("refresh_token")
|
||||
}
|
||||
|
||||
// GetOpenAIIDToken 获取 OpenAI ID Token(JWT,包含用户信息)
|
||||
func (a *Account) GetOpenAIIDToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("id_token")
|
||||
}
|
||||
|
||||
// GetOpenAIApiKey 获取 OpenAI API Key(用于 Response 账号)
|
||||
func (a *Account) GetOpenAIApiKey() string {
|
||||
if !a.IsOpenAIApiKey() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("api_key")
|
||||
}
|
||||
|
||||
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
|
||||
// 返回空字符串表示透传原始 User-Agent
|
||||
func (a *Account) GetOpenAIUserAgent() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("user_agent")
|
||||
}
|
||||
|
||||
// GetChatGPTAccountID 获取 ChatGPT 账号 ID(从 ID Token 解析)
|
||||
func (a *Account) GetChatGPTAccountID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_account_id")
|
||||
}
|
||||
|
||||
// GetChatGPTUserID 获取 ChatGPT 用户 ID(从 ID Token 解析)
|
||||
func (a *Account) GetChatGPTUserID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_user_id")
|
||||
}
|
||||
|
||||
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
|
||||
func (a *Account) GetOpenAIOrganizationID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("organization_id")
|
||||
}
|
||||
|
||||
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
|
||||
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return nil
|
||||
}
|
||||
expiresAtStr := a.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return nil
|
||||
}
|
||||
// 尝试解析时间
|
||||
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||
if err != nil {
|
||||
// 尝试解析为 Unix 时间戳
|
||||
if v, ok := a.Credentials["expires_at"].(float64); ok {
|
||||
t = time.Unix(int64(v), 0)
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
|
||||
func (a *Account) IsOpenAITokenExpired() bool {
|
||||
expiresAt := a.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false // 没有过期时间信息,假设未过期
|
||||
}
|
||||
// 提前 60 秒认为过期,便于刷新
|
||||
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ type RedeemCode struct {
|
||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `gorm:"type:text" json:"notes"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
|
||||
// 订阅类型专用字段
|
||||
|
||||
@@ -42,6 +42,7 @@ const (
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
@@ -80,6 +81,7 @@ type SystemSettings struct {
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -97,5 +99,6 @@ type PublicSettings struct {
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -11,6 +11,9 @@ import (
|
||||
type User struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
Username string `gorm:"size:100;default:''" json:"username"`
|
||||
Wechat string `gorm:"size:100;default:''" json:"wechat"`
|
||||
Notes string `gorm:"type:text;default:''" json:"notes"`
|
||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||
@@ -22,7 +25,8 @@ type User struct {
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
|
||||
@@ -43,18 +43,25 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
@@ -87,14 +94,20 @@ func (s *SessionStore) Delete(sessionID string) {
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
42
backend/internal/pkg/openai/constants.go
Normal file
42
backend/internal/pkg/openai/constants.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package openai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
// Model represents an OpenAI model
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
|
||||
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
|
||||
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
|
||||
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
|
||||
}
|
||||
|
||||
// DefaultModelIDs returns the default model ID list
|
||||
func DefaultModelIDs() []string {
|
||||
ids := make([]string, len(DefaultModels))
|
||||
for i, m := range DefaultModels {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// DefaultTestModel default model for testing OpenAI accounts
|
||||
const DefaultTestModel = "gpt-5.1-codex"
|
||||
|
||||
// DefaultInstructions default instructions for non-Codex CLI requests
|
||||
// Content loaded from instructions.txt at compile time
|
||||
//
|
||||
//go:embed instructions.txt
|
||||
var DefaultInstructions string
|
||||
118
backend/internal/pkg/openai/instructions.txt
Normal file
118
backend/internal/pkg/openai/instructions.txt
Normal file
@@ -0,0 +1,118 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No \"save/copy this file\" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no \"above/below\"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5
|
||||
|
||||
366
backend/internal/pkg/openai/oauth.go
Normal file
366
backend/internal/pkg/openai/oauth.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
TokenURL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
// Default redirect URI (can be customized)
|
||||
DefaultRedirectURI = "http://localhost:1455/auth/callback"
|
||||
|
||||
// Scopes
|
||||
DefaultScopes = "openid profile email offline_access"
|
||||
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
|
||||
RefreshScopes = "openid profile email"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore manages OAuth sessions in memory
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
// Get retrieves a session
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
// Delete removes a session
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
|
||||
// OpenAI uses hex encoding instead of base64url
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
// Uses base64url encoding as per RFC 7636
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", DefaultScopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
// OpenAI specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OpenAI OAuth
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IDTokenClaims represents the claims from OpenAI ID Token
|
||||
type IDTokenClaims struct {
|
||||
// Standard claims
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Iss string `json:"iss"`
|
||||
Aud []string `json:"aud"` // OpenAI returns aud as an array
|
||||
Exp int64 `json:"exp"`
|
||||
Iat int64 `json:"iat"`
|
||||
|
||||
// OpenAI specific claims (nested under https://api.openai.com/auth)
|
||||
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIAuthClaims represents the OpenAI specific auth claims
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
|
||||
// OrganizationClaim represents an organization in the ID Token
|
||||
type OrganizationClaim struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
Title string `json:"title"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request for OpenAI
|
||||
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: redirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
Scope: RefreshScopes,
|
||||
}
|
||||
}
|
||||
|
||||
// ToFormData converts TokenRequest to URL-encoded form data
|
||||
func (r *TokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("code", r.Code)
|
||||
params.Set("redirect_uri", r.RedirectURI)
|
||||
params.Set("code_verifier", r.CodeVerifier)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ToFormData converts RefreshTokenRequest to URL-encoded form data
|
||||
func (r *RefreshTokenRequest) ToFormData() string {
|
||||
params := url.Values{}
|
||||
params.Set("grant_type", r.GrantType)
|
||||
params.Set("client_id", r.ClientID)
|
||||
params.Set("refresh_token", r.RefreshToken)
|
||||
params.Set("scope", r.Scope)
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||
// Note: This does NOT verify the signature - it only decodes the payload
|
||||
// For production, you should verify the token signature using OpenAI's public keys
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if necessary
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try standard encoding
|
||||
decoded, err = base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var claims IDTokenClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ExtractUserInfo extracts user information from ID Token claims
|
||||
type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
}
|
||||
|
||||
// GetUserInfo extracts user info from ID Token claims
|
||||
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
info := &UserInfo{
|
||||
Email: c.Email,
|
||||
}
|
||||
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
// Get default organization ID
|
||||
for _, org := range c.OpenAIAuth.Organizations {
|
||||
if org.IsDefault {
|
||||
info.OrganizationID = org.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no default, use first org
|
||||
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
|
||||
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
18
backend/internal/pkg/openai/request.go
Normal file
18
backend/internal/pkg/openai/request.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package openai
|
||||
|
||||
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||
var CodexCLIUserAgentPrefixes = []string{
|
||||
"codex_vscode/",
|
||||
"codex_cli_rs/",
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
209
backend/internal/pkg/usagestats/usage_log_types.go
Normal file
209
backend/internal/pkg/usagestats/usage_log_types.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package usagestats
|
||||
|
||||
import "time"
|
||||
|
||||
// DashboardStats 仪表盘统计
|
||||
type DashboardStats struct {
|
||||
// 用户统计
|
||||
TotalUsers int64 `json:"total_users"`
|
||||
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
|
||||
// 账户统计
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
|
||||
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
|
||||
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
|
||||
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 系统运行统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
||||
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
}
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 性能统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
}
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats struct {
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageHistory represents daily usage history for an account
|
||||
type AccountUsageHistory struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
}
|
||||
|
||||
// AccountUsageSummary represents summary statistics for an account
|
||||
type AccountUsageSummary struct {
|
||||
Days int `json:"days"`
|
||||
ActualDaysUsed int `json:"actual_days_used"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalStandardCost float64 `json:"total_standard_cost"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
AvgDailyCost float64 `json:"avg_daily_cost"`
|
||||
AvgDailyRequests float64 `json:"avg_daily_requests"`
|
||||
AvgDailyTokens float64 `json:"avg_daily_tokens"`
|
||||
AvgDurationMs float64 `json:"avg_duration_ms"`
|
||||
Today *struct {
|
||||
Date string `json:"date"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
} `json:"today"`
|
||||
HighestCostDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
} `json:"highest_cost_day"`
|
||||
HighestRequestDay *struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Cost float64 `json:"cost"`
|
||||
} `json:"highest_request_day"`
|
||||
}
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
}
|
||||
@@ -2,11 +2,14 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type AccountRepository struct {
|
||||
@@ -23,14 +26,34 @@ func (r *AccountRepository) Create(ctx context.Context, account *model.Account)
|
||||
|
||||
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 填充 GroupIDs 虚拟字段
|
||||
// 填充 GroupIDs 和 Groups 虚拟字段
|
||||
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
|
||||
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
|
||||
for _, ag := range account.AccountGroups {
|
||||
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
account.Groups = append(account.Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
|
||||
if crsAccountID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
@@ -78,15 +101,19 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
||||
if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 填充每个 Account 的 GroupIDs 虚拟字段
|
||||
// 填充每个 Account 的虚拟字段(GroupIDs 和 Groups)
|
||||
for i := range accounts {
|
||||
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
|
||||
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
|
||||
for _, ag := range accounts[i].AccountGroups {
|
||||
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,6 +249,38 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// ListSchedulableByPlatform 按平台获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ?", platform).
|
||||
Where("status = ? AND schedulable = ?", model.StatusActive, true).
|
||||
Where("(overload_until IS NULL OR overload_until <= ?)", now).
|
||||
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ?", groupID).
|
||||
Where("accounts.platform = ?", platform).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
|
||||
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
|
||||
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
// SetRateLimited 标记账号为限流状态(429)
|
||||
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
now := time.Now()
|
||||
@@ -267,3 +326,75 @@ func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("schedulable", schedulable).Error
|
||||
}
|
||||
|
||||
// UpdateExtra updates specific fields in account's Extra JSONB field
|
||||
// It merges the updates into existing Extra data without overwriting other fields
|
||||
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get current account to preserve existing Extra data
|
||||
var account model.Account
|
||||
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Extra if nil
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(model.JSONB)
|
||||
}
|
||||
|
||||
// Merge updates into existing Extra
|
||||
for k, v := range updates {
|
||||
account.Extra[k] = v
|
||||
}
|
||||
|
||||
// Save updated Extra
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("extra", account.Extra).Error
|
||||
}
|
||||
|
||||
// BulkUpdate updates multiple accounts with the provided fields.
|
||||
// It merges credentials/extra JSONB fields instead of overwriting them.
|
||||
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
updateMap := map[string]any{}
|
||||
|
||||
if updates.Name != nil {
|
||||
updateMap["name"] = *updates.Name
|
||||
}
|
||||
if updates.ProxyID != nil {
|
||||
updateMap["proxy_id"] = updates.ProxyID
|
||||
}
|
||||
if updates.Concurrency != nil {
|
||||
updateMap["concurrency"] = *updates.Concurrency
|
||||
}
|
||||
if updates.Priority != nil {
|
||||
updateMap["priority"] = *updates.Priority
|
||||
}
|
||||
if updates.Status != nil {
|
||||
updateMap["status"] = *updates.Status
|
||||
}
|
||||
if len(updates.Credentials) > 0 {
|
||||
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
|
||||
}
|
||||
if len(updates.Extra) > 0 {
|
||||
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
|
||||
}
|
||||
|
||||
if len(updateMap) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&model.Account{}).
|
||||
Where("id IN ?", ids).
|
||||
Clauses(clause.Returning{}).
|
||||
Updates(updateMap)
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
@@ -139,20 +140,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
// Parse code which may contain state in format "authCode#state"
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if len(code) > 0 {
|
||||
parts := make([]string, 0, 2)
|
||||
for i, part := range []rune(code) {
|
||||
if part == '#' {
|
||||
authCode = code[:i]
|
||||
codeState = code[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
authCode = code
|
||||
}
|
||||
if idx := strings.Index(code, "#"); idx != -1 {
|
||||
authCode = code[:idx]
|
||||
codeState = code[idx+1:]
|
||||
}
|
||||
|
||||
reqBody := map[string]any{
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type claudeUsageService struct{}
|
||||
|
||||
@@ -5,60 +5,95 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
accountConcurrencyKeyPrefix = "concurrency:account:"
|
||||
userConcurrencyKeyPrefix = "concurrency:user:"
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
concurrencyTTL = 5 * time.Minute
|
||||
// Key prefixes for independent slot keys
|
||||
// Format: concurrency:account:{accountID}:{requestID}
|
||||
accountSlotKeyPrefix = "concurrency:account:"
|
||||
// Format: concurrency:user:{userID}:{requestID}
|
||||
userSlotKeyPrefix = "concurrency:user:"
|
||||
// Wait queue keeps counter format: concurrency:wait:{userID}
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
|
||||
// Slot TTL - each slot expires independently
|
||||
slotTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
|
||||
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
|
||||
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
|
||||
// ARGV[1] = maxConcurrency
|
||||
// ARGV[2] = TTL in seconds
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
local pattern = KEYS[1]
|
||||
local slotKey = KEYS[2]
|
||||
local maxConcurrency = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
|
||||
-- Count existing slots using SCAN
|
||||
local cursor = "0"
|
||||
local count = 0
|
||||
repeat
|
||||
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||
cursor = result[1]
|
||||
count = count + #result[2]
|
||||
until cursor == "0"
|
||||
|
||||
-- Check if we can acquire a slot
|
||||
if count < maxConcurrency then
|
||||
redis.call('SET', slotKey, '1', 'EX', ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
return 0
|
||||
`)
|
||||
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
// getCountScript counts slots using SCAN
|
||||
// KEYS[1] = pattern for SCAN
|
||||
getCountScript = redis.NewScript(`
|
||||
local pattern = KEYS[1]
|
||||
local cursor = "0"
|
||||
local count = 0
|
||||
repeat
|
||||
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||
cursor = result[1]
|
||||
count = count + #result[2]
|
||||
until cursor == "0"
|
||||
return count
|
||||
`)
|
||||
|
||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||
// KEYS[1] = wait queue key
|
||||
// ARGV[1] = maxWait
|
||||
// ARGV[2] = TTL in seconds
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript - same as before
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
@@ -76,49 +111,86 @@ func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
|
||||
return &concurrencyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
// Helper functions for key generation
|
||||
func accountSlotKey(accountID int64, requestID string) string {
|
||||
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
|
||||
}
|
||||
|
||||
func accountSlotPattern(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func userSlotKey(userID int64, requestID string) string {
|
||||
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
|
||||
}
|
||||
|
||||
func userSlotPattern(userID int64) string {
|
||||
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
}
|
||||
|
||||
// Account slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
pattern := accountSlotPattern(accountID)
|
||||
slotKey := accountSlotKey(accountID, requestID)
|
||||
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
slotKey := accountSlotKey(accountID, requestID)
|
||||
return c.rdb.Del(ctx, slotKey).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
pattern := accountSlotPattern(accountID)
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
// User slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
pattern := userSlotPattern(userID)
|
||||
slotKey := userSlotKey(userID, requestID)
|
||||
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
slotKey := userSlotKey(userID, requestID)
|
||||
return c.rdb.Del(ctx, slotKey).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
pattern := userSlotPattern(userID)
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -126,7 +198,7 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
key := waitQueueKey(userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
|
||||
@@ -2,8 +2,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -5,16 +5,19 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
type claudeUpstreamService struct {
|
||||
// httpUpstreamService is a generic HTTP upstream service that can be used for
|
||||
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
|
||||
type httpUpstreamService struct {
|
||||
defaultClient *http.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||
// NewHTTPUpstream creates a new generic HTTP upstream service
|
||||
func NewHTTPUpstream(cfg *config.Config) ports.HTTPUpstream {
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
@@ -27,13 +30,13 @@ func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &claudeUpstreamService{
|
||||
return &httpUpstreamService{
|
||||
defaultClient: &http.Client{Transport: transport},
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
if proxyURL == "" {
|
||||
return s.defaultClient.Do(req)
|
||||
}
|
||||
@@ -41,7 +44,7 @@ func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Re
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return s.defaultClient
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
92
backend/internal/repository/openai_oauth_service.go
Normal file
92
backend/internal/repository/openai_oauth_service.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type openaiOAuthService struct{}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
||||
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
|
||||
return &openaiOAuthService{}
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("code", code)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
formData.Set("code_verifier", codeVerifier)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(openai.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("scope", openai.RefreshScopes)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(openai.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type pricingRemoteClient struct {
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -2,7 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -2,10 +2,10 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -19,6 +19,30 @@ func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository {
|
||||
return &UsageLogRepository{db: db}
|
||||
}
|
||||
|
||||
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
|
||||
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
|
||||
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
|
||||
var perfStats struct {
|
||||
RequestCount int64 `gorm:"column:request_count"`
|
||||
TokenCount int64 `gorm:"column:token_count"`
|
||||
}
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
|
||||
`).
|
||||
Where("created_at >= ?", fiveMinutesAgo)
|
||||
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
|
||||
db.Scan(&perfStats)
|
||||
// 返回5分钟平均值
|
||||
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
@@ -113,46 +137,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
||||
}
|
||||
|
||||
// DashboardStats 仪表盘统计
|
||||
type DashboardStats struct {
|
||||
// 用户统计
|
||||
TotalUsers int64 `json:"total_users"`
|
||||
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
|
||||
// 账户统计
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
|
||||
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
|
||||
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
|
||||
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 系统运行统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
||||
}
|
||||
type DashboardStats = usagestats.DashboardStats
|
||||
|
||||
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
var stats DashboardStats
|
||||
@@ -269,6 +254,9 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
stats.TodayCost = todayStats.TodayCost
|
||||
stats.TodayActualCost = todayStats.TodayActualCost
|
||||
|
||||
// 性能指标:RPM 和 TPM(最近1分钟,全局)
|
||||
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, 0)
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
@@ -398,47 +386,16 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
type TrendDataPoint = usagestats.TrendDataPoint
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
type ModelStat = usagestats.ModelStat
|
||||
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
|
||||
@@ -531,34 +488,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
|
||||
}
|
||||
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||
|
||||
// 今日 Token 使用统计
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayInputTokens int64 `json:"today_input_tokens"`
|
||||
TodayOutputTokens int64 `json:"today_output_tokens"`
|
||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||
|
||||
// 性能统计
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
type UserDashboardStats = usagestats.UserDashboardStats
|
||||
|
||||
// GetUserDashboardStats 获取用户专属的仪表盘统计
|
||||
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
|
||||
@@ -641,6 +571,9 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
stats.TodayCost = todayStats.TodayCost
|
||||
stats.TodayActualCost = todayStats.TodayActualCost
|
||||
|
||||
// 性能指标:RPM 和 TPM(最近1分钟,仅统计该用户的请求)
|
||||
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, userID)
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
@@ -705,12 +638,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
}
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
}
|
||||
type UsageLogFilters = usagestats.UsageLogFilters
|
||||
|
||||
// ListWithFilters lists usage logs with optional filters (for admin)
|
||||
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
@@ -758,23 +686,10 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
type UsageStats = usagestats.UsageStats
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
||||
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||||
@@ -834,11 +749,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats struct {
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
|
||||
|
||||
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
||||
@@ -937,7 +848,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]ModelStat, error) {
|
||||
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
@@ -958,6 +869,9 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
if apiKeyID > 0 {
|
||||
db = db.Where("api_key_id = ?", apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
db = db.Where("account_id = ?", accountID)
|
||||
}
|
||||
|
||||
err := db.Group("model").Order("total_tokens DESC").Scan(&results).Error
|
||||
if err != nil {
|
||||
@@ -1007,3 +921,169 @@ func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AccountUsageHistory represents daily usage history for an account
|
||||
type AccountUsageHistory = usagestats.AccountUsageHistory
|
||||
|
||||
// AccountUsageSummary represents summary statistics for an account
|
||||
type AccountUsageSummary = usagestats.AccountUsageSummary
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||
|
||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
|
||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||
if daysCount <= 0 {
|
||||
daysCount = 30
|
||||
}
|
||||
|
||||
// Get daily history
|
||||
var historyResults []struct {
|
||||
Date string `gorm:"column:date"`
|
||||
Requests int64 `gorm:"column:requests"`
|
||||
Tokens int64 `gorm:"column:tokens"`
|
||||
Cost float64 `gorm:"column:cost"`
|
||||
ActualCost float64 `gorm:"column:actual_cost"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`).
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
Group("date").
|
||||
Order("date ASC").
|
||||
Scan(&historyResults).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build history with labels
|
||||
history := make([]AccountUsageHistory, 0, len(historyResults))
|
||||
for _, h := range historyResults {
|
||||
// Parse date to get label (MM/DD)
|
||||
t, _ := time.Parse("2006-01-02", h.Date)
|
||||
label := t.Format("01/02")
|
||||
history = append(history, AccountUsageHistory{
|
||||
Date: h.Date,
|
||||
Label: label,
|
||||
Requests: h.Requests,
|
||||
Tokens: h.Tokens,
|
||||
Cost: h.Cost,
|
||||
ActualCost: h.ActualCost,
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate summary
|
||||
var totalActualCost, totalStandardCost float64
|
||||
var totalRequests, totalTokens int64
|
||||
var highestCostDay, highestRequestDay *AccountUsageHistory
|
||||
|
||||
for i := range history {
|
||||
h := &history[i]
|
||||
totalActualCost += h.ActualCost
|
||||
totalStandardCost += h.Cost
|
||||
totalRequests += h.Requests
|
||||
totalTokens += h.Tokens
|
||||
|
||||
if highestCostDay == nil || h.ActualCost > highestCostDay.ActualCost {
|
||||
highestCostDay = h
|
||||
}
|
||||
if highestRequestDay == nil || h.Requests > highestRequestDay.Requests {
|
||||
highestRequestDay = h
|
||||
}
|
||||
}
|
||||
|
||||
actualDaysUsed := len(history)
|
||||
if actualDaysUsed == 0 {
|
||||
actualDaysUsed = 1
|
||||
}
|
||||
|
||||
// Get average duration
|
||||
var avgDuration struct {
|
||||
AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
Scan(&avgDuration)
|
||||
|
||||
summary := AccountUsageSummary{
|
||||
Days: daysCount,
|
||||
ActualDaysUsed: actualDaysUsed,
|
||||
TotalCost: totalActualCost,
|
||||
TotalStandardCost: totalStandardCost,
|
||||
TotalRequests: totalRequests,
|
||||
TotalTokens: totalTokens,
|
||||
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
|
||||
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
||||
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
||||
AvgDurationMs: avgDuration.AvgDurationMs,
|
||||
}
|
||||
|
||||
// Set today's stats
|
||||
todayStr := timezone.Now().Format("2006-01-02")
|
||||
for i := range history {
|
||||
if history[i].Date == todayStr {
|
||||
summary.Today = &struct {
|
||||
Date string `json:"date"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}{
|
||||
Date: history[i].Date,
|
||||
Cost: history[i].ActualCost,
|
||||
Requests: history[i].Requests,
|
||||
Tokens: history[i].Tokens,
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Set highest cost day
|
||||
if highestCostDay != nil {
|
||||
summary.HighestCostDay = &struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
}{
|
||||
Date: highestCostDay.Date,
|
||||
Label: highestCostDay.Label,
|
||||
Cost: highestCostDay.ActualCost,
|
||||
Requests: highestCostDay.Requests,
|
||||
}
|
||||
}
|
||||
|
||||
// Set highest request day
|
||||
if highestRequestDay != nil {
|
||||
summary.HighestRequestDay = &struct {
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Cost float64 `json:"cost"`
|
||||
}{
|
||||
Date: highestRequestDay.Date,
|
||||
Label: highestRequestDay.Label,
|
||||
Requests: highestRequestDay.Requests,
|
||||
Cost: highestRequestDay.ActualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// Get model statistics using the unified method
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
|
||||
return &AccountUsageStatsResponse{
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -66,17 +66,47 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where("email ILIKE ?", searchPattern)
|
||||
db = db.Where(
|
||||
"email ILIKE ? OR username ILIKE ? OR wechat ILIKE ?",
|
||||
searchPattern, searchPattern, searchPattern,
|
||||
)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Query users with pagination (reuse the same db with filters applied)
|
||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Batch load subscriptions for all users (avoid N+1)
|
||||
if len(users) > 0 {
|
||||
userIDs := make([]int64, len(users))
|
||||
userMap := make(map[int64]*model.User, len(users))
|
||||
for i := range users {
|
||||
userIDs[i] = users[i].ID
|
||||
userMap[users[i].ID] = &users[i]
|
||||
}
|
||||
|
||||
// Query active subscriptions with groups in one query
|
||||
var subscriptions []model.UserSubscription
|
||||
if err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive).
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Associate subscriptions with users
|
||||
for i := range subscriptions {
|
||||
if user, ok := userMap[subscriptions[i].UserID]; ok {
|
||||
user.Subscriptions = append(user.Subscriptions, subscriptions[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
@@ -36,7 +36,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewProxyExitInfoProber,
|
||||
NewClaudeUsageFetcher,
|
||||
NewClaudeOAuthClient,
|
||||
NewClaudeUpstream,
|
||||
NewHTTPUpstream,
|
||||
NewOpenAIOAuthClient,
|
||||
|
||||
// Bind concrete repositories to service port interfaces
|
||||
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"net/http"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
"net/http"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -82,6 +82,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
{
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
@@ -179,6 +180,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
@@ -191,8 +193,10 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
|
||||
// OAuth routes
|
||||
// Claude OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
||||
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
||||
@@ -201,6 +205,16 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
||||
}
|
||||
|
||||
// OpenAI OAuth routes
|
||||
openai := admin.Group("/openai")
|
||||
{
|
||||
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||
}
|
||||
|
||||
// 代理管理
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
@@ -289,5 +303,10 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses)
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -14,15 +14,19 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testOpenAIAPIURL = "https://api.openai.com/v1/responses"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
)
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
@@ -36,17 +40,19 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
claudeUpstream ClaudeUpstream
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream ports.HTTPUpstream) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
claudeUpstream: claudeUpstream,
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,6 +120,18 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
return s.sendErrorAndEnd(c, "Account not found")
|
||||
}
|
||||
|
||||
// Route to platform-specific test method
|
||||
if account.IsOpenAI() {
|
||||
return s.testOpenAIAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
// testClaudeAccountConnection tests an Anthropic Claude account's connection
|
||||
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
@@ -222,7 +240,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.claudeUpstream.Do(req, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -234,11 +252,155 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processStream(c, resp.Body)
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// processStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error {
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Default to openai.DefaultTestModel for OpenAI testing
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = openai.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
var authToken string
|
||||
var apiURL string
|
||||
var isOAuth bool
|
||||
var chatgptAccountID string
|
||||
|
||||
if account.IsOAuth() {
|
||||
isOAuth = true
|
||||
// OAuth - use Bearer token with ChatGPT internal API
|
||||
authToken = account.GetOpenAIAccessToken()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
// Check if token is expired and refresh if needed
|
||||
if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
|
||||
tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
||||
}
|
||||
authToken = tokenInfo.AccessToken
|
||||
}
|
||||
|
||||
// OAuth uses ChatGPT internal API
|
||||
apiURL = chatgptCodexAPIURL
|
||||
chatgptAccountID = account.GetChatGPTAccountID()
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key - use Platform API
|
||||
authToken = account.GetOpenAIApiKey()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(baseURL, "/") + "/v1/responses"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create OpenAI Responses API payload
|
||||
payload := createOpenAITestPayload(testModelID, isOAuth)
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
// Set common headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
|
||||
// Set OAuth-specific headers for ChatGPT internal API
|
||||
if isOAuth {
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processOpenAIStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// createOpenAITestPayload creates a test payload for OpenAI Responses API
|
||||
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
|
||||
payload := map[string]any{
|
||||
"model": modelID,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "hi",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
// OAuth accounts using ChatGPT internal API require store: false
|
||||
if isOAuth {
|
||||
payload["store"] = false
|
||||
}
|
||||
|
||||
// All accounts require instructions for Responses API
|
||||
payload["instructions"] = openai.DefaultInstructions
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// processClaudeStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
@@ -291,6 +453,59 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
}
|
||||
}
|
||||
|
||||
// processOpenAIStream processes the SSE stream from OpenAI Responses API
|
||||
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := data["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "response.output_text.delta":
|
||||
// OpenAI Responses API uses "delta" field for text content
|
||||
if delta, ok := data["delta"].(string); ok && delta != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
|
||||
}
|
||||
case "response.completed":
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendEvent sends a SSE event to the client
|
||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
|
||||
@@ -7,8 +7,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
@@ -176,6 +177,14 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account usage stats failed: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -22,7 +22,7 @@ type AdminService interface {
|
||||
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
@@ -45,6 +45,7 @@ type AdminService interface {
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
|
||||
// Proxy management
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
|
||||
@@ -71,6 +72,9 @@ type AdminService interface {
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Username string
|
||||
Wechat string
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
@@ -79,6 +83,9 @@ type CreateUserInput struct {
|
||||
type UpdateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Username *string
|
||||
Wechat *string
|
||||
Notes *string
|
||||
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
@@ -134,6 +141,33 @@ type UpdateAccountInput struct {
|
||||
GroupIDs *[]int64
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
type BulkUpdateAccountsInput struct {
|
||||
AccountIDs []int64
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// BulkUpdateAccountResult captures the result for a single account update.
|
||||
type BulkUpdateAccountResult struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||
type BulkUpdateAccountsResult struct {
|
||||
Success int `json:"success"`
|
||||
Failed int `json:"failed"`
|
||||
Results []BulkUpdateAccountResult `json:"results"`
|
||||
}
|
||||
|
||||
type CreateProxyInput struct {
|
||||
Name string
|
||||
Protocol string
|
||||
@@ -237,6 +271,9 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User,
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
|
||||
user := &model.User{
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Wechat: input.Wechat,
|
||||
Notes: input.Notes,
|
||||
Role: "user", // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
@@ -262,8 +299,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
return nil, errors.New("cannot disable admin user")
|
||||
}
|
||||
|
||||
// Track balance and concurrency changes for logging
|
||||
oldBalance := user.Balance
|
||||
oldConcurrency := user.Concurrency
|
||||
|
||||
if input.Email != "" {
|
||||
@@ -274,22 +309,25 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Role is not allowed to be changed via API to prevent privilege escalation
|
||||
|
||||
if input.Username != nil {
|
||||
user.Username = *input.Username
|
||||
}
|
||||
if input.Wechat != nil {
|
||||
user.Wechat = *input.Wechat
|
||||
}
|
||||
if input.Notes != nil {
|
||||
user.Notes = *input.Notes
|
||||
}
|
||||
|
||||
if input.Status != "" {
|
||||
user.Status = input.Status
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 Balance(支持设置为 0)
|
||||
if input.Balance != nil {
|
||||
user.Balance = *input.Balance
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 Concurrency(支持设置为任意值)
|
||||
if input.Concurrency != nil {
|
||||
user.Concurrency = *input.Concurrency
|
||||
}
|
||||
|
||||
// 只在指针非 nil 时更新 AllowedGroups
|
||||
if input.AllowedGroups != nil {
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
@@ -298,41 +336,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 余额变化时失效缓存
|
||||
if input.Balance != nil && *input.Balance != oldBalance {
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, id); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", id, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Create adjustment records for balance/concurrency changes
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminBalance,
|
||||
Value: balanceDiff,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
}
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||
if concurrencyDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
@@ -369,12 +372,14 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
||||
return s.userRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) {
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldBalance := user.Balance
|
||||
|
||||
switch operation {
|
||||
case "set":
|
||||
user.Balance = balance
|
||||
@@ -384,11 +389,14 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
user.Balance -= balance
|
||||
}
|
||||
|
||||
if user.Balance < 0 {
|
||||
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@@ -399,6 +407,30 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
}()
|
||||
}
|
||||
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminBalance,
|
||||
Value: balanceDiff,
|
||||
Status: model.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
Notes: notes,
|
||||
}
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -690,6 +722,65 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// BulkUpdateAccounts updates multiple accounts in one request.
|
||||
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||
result := &BulkUpdateAccountsResult{
|
||||
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
||||
}
|
||||
|
||||
if len(input.AccountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Prepare bulk updates for columns and JSONB fields.
|
||||
repoUpdates := ports.AccountBulkUpdate{
|
||||
Credentials: input.Credentials,
|
||||
Extra: input.Extra,
|
||||
}
|
||||
if input.Name != "" {
|
||||
repoUpdates.Name = &input.Name
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
repoUpdates.ProxyID = input.ProxyID
|
||||
}
|
||||
if input.Concurrency != nil {
|
||||
repoUpdates.Concurrency = input.Concurrency
|
||||
}
|
||||
if input.Priority != nil {
|
||||
repoUpdates.Priority = input.Priority
|
||||
}
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
|
||||
// Run bulk update for column/jsonb fields first.
|
||||
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle group bindings per account (requires individual operations).
|
||||
for _, accountID := range input.AccountIDs {
|
||||
entry := BulkUpdateAccountResult{AccountID: accountID}
|
||||
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
entry.Success = true
|
||||
result.Success++
|
||||
result.Results = append(result.Results, entry)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -455,3 +455,11 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search api keys: %w", err)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"log"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -23,6 +23,7 @@ var (
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
ErrEmailVerifyRequired = errors.New("email verification is required")
|
||||
ErrRegDisabled = errors.New("registration is currently disabled")
|
||||
ErrServiceUnavailable = errors.New("service temporarily unavailable")
|
||||
)
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
@@ -90,7 +91,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("check email exists: %w", err)
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
return "", nil, ErrEmailExists
|
||||
@@ -121,7 +123,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return "", nil, fmt.Errorf("create user: %w", err)
|
||||
log.Printf("[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 生成token
|
||||
@@ -148,7 +151,8 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check email exists: %w", err)
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
return ErrEmailExists
|
||||
@@ -181,8 +185,8 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Error checking email exists: %v", err)
|
||||
return nil, fmt.Errorf("check email exists: %w", err)
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
log.Printf("[Auth] Email already exists: %s", email)
|
||||
@@ -254,7 +258,9 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", nil, ErrInvalidCredentials
|
||||
}
|
||||
return "", nil, fmt.Errorf("get user by email: %w", err)
|
||||
// 记录数据库错误但不暴露给用户
|
||||
log.Printf("[Auth] Database error during login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
@@ -354,7 +360,8 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
return "", fmt.Errorf("get user: %w", err)
|
||||
log.Printf("[Auth] Database error refreshing token: %v", err)
|
||||
return "", ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
|
||||
@@ -2,9 +2,9 @@ package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"log"
|
||||
"strings"
|
||||
"sub2api/internal/config"
|
||||
)
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
|
||||
@@ -2,12 +2,26 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking
|
||||
// Uses 8 random bytes (16 hex chars) for uniqueness
|
||||
func generateRequestID() string {
|
||||
b := make([]byte, 8)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to nanosecond timestamp (extremely rare case)
|
||||
return fmt.Sprintf("%x", time.Now().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
const (
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
@@ -41,7 +55,10 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
|
||||
}, nil
|
||||
}
|
||||
|
||||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
// Generate unique request ID for this slot
|
||||
requestID := generateRequestID()
|
||||
|
||||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -52,8 +69,8 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
|
||||
log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
@@ -77,7 +94,10 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
|
||||
}, nil
|
||||
}
|
||||
|
||||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
// Generate unique request ID for this slot
|
||||
requestID := generateRequestID()
|
||||
|
||||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -88,8 +108,8 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
|
||||
log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
@@ -147,3 +167,20 @@ func CalculateMaxWait(userConcurrency int) int {
|
||||
}
|
||||
return userConcurrency + defaultExtraWaitSlots
|
||||
}
|
||||
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int)
|
||||
|
||||
for _, accountID := range accountIDs {
|
||||
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
|
||||
if err != nil {
|
||||
// If key doesn't exist in Redis, count is 0
|
||||
count = 0
|
||||
}
|
||||
result[accountID] = count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
961
backend/internal/service/crs_sync_service.go
Normal file
961
backend/internal/service/crs_sync_service.go
Normal file
@@ -0,0 +1,961 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
}
|
||||
|
||||
func NewCRSSyncService(
|
||||
accountRepo ports.AccountRepository,
|
||||
proxyRepo ports.ProxyRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
) *CRSSyncService {
|
||||
return &CRSSyncService{
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
type SyncFromCRSInput struct {
|
||||
BaseURL string
|
||||
Username string
|
||||
Password string
|
||||
SyncProxies bool
|
||||
}
|
||||
|
||||
type SyncFromCRSItemResult struct {
|
||||
CRSAccountID string `json:"crs_account_id"`
|
||||
Kind string `json:"kind"`
|
||||
Name string `json:"name"`
|
||||
Action string `json:"action"` // created/updated/failed/skipped
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type SyncFromCRSResult struct {
|
||||
Created int `json:"created"`
|
||||
Updated int `json:"updated"`
|
||||
Skipped int `json:"skipped"`
|
||||
Failed int `json:"failed"`
|
||||
Items []SyncFromCRSItemResult `json:"items"`
|
||||
}
|
||||
|
||||
type crsLoginResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Token string `json:"token"`
|
||||
Message string `json:"message"`
|
||||
Error string `json:"error"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type crsExportResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
ExportedAt string `json:"exportedAt"`
|
||||
ClaudeAccounts []crsClaudeAccount `json:"claudeAccounts"`
|
||||
ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"`
|
||||
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
|
||||
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type crsProxy struct {
|
||||
Protocol string `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type crsClaudeAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
AuthType string `json:"authType"` // oauth/setup-token
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
type crsConsoleAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
MaxConcurrentTasks int `json:"maxConcurrentTasks"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
}
|
||||
|
||||
type crsOpenAIResponsesAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
}
|
||||
|
||||
type crsOpenAIOAuthAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
AuthType string `json:"authType"` // oauth
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
|
||||
return nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 20 * time.Second}
|
||||
|
||||
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
result := &SyncFromCRSResult{
|
||||
Items: make(
|
||||
[]SyncFromCRSItemResult,
|
||||
0,
|
||||
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts),
|
||||
),
|
||||
}
|
||||
|
||||
var proxies []model.Proxy
|
||||
if input.SyncProxies {
|
||||
proxies, _ = s.proxyRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
// Claude OAuth / Setup Token -> sub2api anthropic oauth/setup-token
|
||||
for _, src := range exported.Data.ClaudeAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
targetType := strings.TrimSpace(src.AuthType)
|
||||
if targetType == "" {
|
||||
targetType = "oauth"
|
||||
}
|
||||
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken {
|
||||
item.Action = "skipped"
|
||||
item.Error = "unsupported authType: " + targetType
|
||||
result.Skipped++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
accessToken, _ := src.Credentials["access_token"].(string)
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing access_token"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
// 🔧 Remove /v1 suffix from base_url for Claude accounts
|
||||
cleanBaseURL(credentials, "/v1")
|
||||
// 🔧 Convert expires_at from ISO string to Unix timestamp
|
||||
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
|
||||
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
credentials["expires_at"] = t.Unix()
|
||||
}
|
||||
}
|
||||
// 🔧 Add intercept_warmup_requests if not present (defaults to false)
|
||||
if _, exists := credentials["intercept_warmup_requests"]; !exists {
|
||||
credentials["intercept_warmup_requests"] = false
|
||||
}
|
||||
priority := clampPriority(src.Priority)
|
||||
concurrency := 3
|
||||
status := mapCRSStatus(src.IsActive, src.Status)
|
||||
|
||||
// 🔧 Preserve all CRS extra fields and add sync metadata
|
||||
extra := make(map[string]any)
|
||||
if src.Extra != nil {
|
||||
for k, v := range src.Extra {
|
||||
extra[k] = v
|
||||
}
|
||||
}
|
||||
extra["crs_account_id"] = src.ID
|
||||
extra["crs_kind"] = src.Kind
|
||||
extra["crs_synced_at"] = now
|
||||
// Extract org_uuid and account_uuid from CRS credentials to extra
|
||||
if orgUUID, ok := src.Credentials["org_uuid"]; ok {
|
||||
extra["org_uuid"] = orgUUID
|
||||
}
|
||||
if accountUUID, ok := src.Credentials["account_uuid"]; ok {
|
||||
extra["account_uuid"] = accountUUID
|
||||
}
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformAnthropic,
|
||||
Type: targetType,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
Status: status,
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
// 🔄 Refresh OAuth token after creation
|
||||
if targetType == model.AccountTypeOAuth {
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
}
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
// Update existing
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformAnthropic
|
||||
existing.Type = targetType
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = concurrency
|
||||
existing.Priority = priority
|
||||
existing.Status = status
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
// 🔄 Refresh OAuth token after update
|
||||
if targetType == model.AccountTypeOAuth {
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
}
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
// Claude Console API Key -> sub2api anthropic apikey
|
||||
for _, src := range exported.Data.ClaudeConsoleAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
apiKey, _ := src.Credentials["api_key"].(string)
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing api_key"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
priority := clampPriority(src.Priority)
|
||||
concurrency := 3
|
||||
if src.MaxConcurrentTasks > 0 {
|
||||
concurrency = src.MaxConcurrentTasks
|
||||
}
|
||||
status := mapCRSStatus(src.IsActive, src.Status)
|
||||
|
||||
extra := map[string]any{
|
||||
"crs_account_id": src.ID,
|
||||
"crs_kind": src.Kind,
|
||||
"crs_synced_at": now,
|
||||
}
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformAnthropic,
|
||||
Type: model.AccountTypeApiKey,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
Status: status,
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformAnthropic
|
||||
existing.Type = model.AccountTypeApiKey
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = concurrency
|
||||
existing.Priority = priority
|
||||
existing.Status = status
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
// OpenAI OAuth -> sub2api openai oauth
|
||||
for _, src := range exported.Data.OpenAIOAuthAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
accessToken, _ := src.Credentials["access_token"].(string)
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing access_token"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(
|
||||
ctx,
|
||||
input.SyncProxies,
|
||||
&proxies,
|
||||
src.Proxy,
|
||||
fmt.Sprintf("crs-%s", src.Name),
|
||||
)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
// Normalize token_type
|
||||
if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" {
|
||||
credentials["token_type"] = "Bearer"
|
||||
}
|
||||
// 🔧 Convert expires_at from ISO string to Unix timestamp
|
||||
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
|
||||
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
credentials["expires_at"] = t.Unix()
|
||||
}
|
||||
}
|
||||
priority := clampPriority(src.Priority)
|
||||
concurrency := 3
|
||||
status := mapCRSStatus(src.IsActive, src.Status)
|
||||
|
||||
// 🔧 Preserve all CRS extra fields and add sync metadata
|
||||
extra := make(map[string]any)
|
||||
if src.Extra != nil {
|
||||
for k, v := range src.Extra {
|
||||
extra[k] = v
|
||||
}
|
||||
}
|
||||
extra["crs_account_id"] = src.ID
|
||||
extra["crs_kind"] = src.Kind
|
||||
extra["crs_synced_at"] = now
|
||||
// Extract email from CRS extra (crs_email -> email)
|
||||
if crsEmail, ok := src.Extra["crs_email"]; ok {
|
||||
extra["email"] = crsEmail
|
||||
}
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformOpenAI,
|
||||
Type: model.AccountTypeOAuth,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
Status: status,
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
// 🔄 Refresh OAuth token after creation
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformOpenAI
|
||||
existing.Type = model.AccountTypeOAuth
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = concurrency
|
||||
existing.Priority = priority
|
||||
existing.Status = status
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
// 🔄 Refresh OAuth token after update
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
// OpenAI Responses API Key -> sub2api openai apikey
|
||||
for _, src := range exported.Data.OpenAIResponsesAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
apiKey, _ := src.Credentials["api_key"].(string)
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing api_key"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if baseURL, ok := src.Credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" {
|
||||
src.Credentials["base_url"] = "https://api.openai.com"
|
||||
}
|
||||
// 🔧 Remove /v1 suffix from base_url for OpenAI accounts
|
||||
cleanBaseURL(src.Credentials, "/v1")
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(
|
||||
ctx,
|
||||
input.SyncProxies,
|
||||
&proxies,
|
||||
src.Proxy,
|
||||
fmt.Sprintf("crs-%s", src.Name),
|
||||
)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
priority := clampPriority(src.Priority)
|
||||
concurrency := 3
|
||||
status := mapCRSStatus(src.IsActive, src.Status)
|
||||
|
||||
extra := map[string]any{
|
||||
"crs_account_id": src.ID,
|
||||
"crs_kind": src.Kind,
|
||||
"crs_synced_at": now,
|
||||
}
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformOpenAI,
|
||||
Type: model.AccountTypeApiKey,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
Status: status,
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformOpenAI
|
||||
existing.Type = model.AccountTypeApiKey
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = concurrency
|
||||
existing.Priority = priority
|
||||
existing.Status = status
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates.
|
||||
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
|
||||
out := make(model.JSONB)
|
||||
for k, v := range existing {
|
||||
out[k] = v
|
||||
}
|
||||
for k, v := range updates {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) {
|
||||
if !enabled || src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
protocol := strings.ToLower(strings.TrimSpace(src.Protocol))
|
||||
switch protocol {
|
||||
case "socks":
|
||||
protocol = "socks5"
|
||||
case "socks5h":
|
||||
protocol = "socks5"
|
||||
}
|
||||
host := strings.TrimSpace(src.Host)
|
||||
port := src.Port
|
||||
username := strings.TrimSpace(src.Username)
|
||||
password := strings.TrimSpace(src.Password)
|
||||
|
||||
if protocol == "" || host == "" || port <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if protocol != "http" && protocol != "https" && protocol != "socks5" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Find existing proxy (active only).
|
||||
for _, p := range *cached {
|
||||
if strings.EqualFold(p.Protocol, protocol) &&
|
||||
p.Host == host &&
|
||||
p.Port == port &&
|
||||
p.Username == username &&
|
||||
p.Password == password {
|
||||
id := p.ID
|
||||
return &id, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create new proxy
|
||||
proxy := &model.Proxy{
|
||||
Name: defaultProxyName(defaultName, protocol, host, port),
|
||||
Protocol: protocol,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
*cached = append(*cached, *proxy)
|
||||
id := proxy.ID
|
||||
return &id, nil
|
||||
}
|
||||
|
||||
func defaultProxyName(base, protocol, host string, port int) string {
|
||||
base = strings.TrimSpace(base)
|
||||
if base == "" {
|
||||
base = "crs"
|
||||
}
|
||||
return fmt.Sprintf("%s (%s://%s:%d)", base, protocol, host, port)
|
||||
}
|
||||
|
||||
func defaultName(name, id string) string {
|
||||
if strings.TrimSpace(name) != "" {
|
||||
return strings.TrimSpace(name)
|
||||
}
|
||||
return "CRS " + id
|
||||
}
|
||||
|
||||
func clampPriority(priority int) int {
|
||||
if priority < 1 || priority > 100 {
|
||||
return 50
|
||||
}
|
||||
return priority
|
||||
}
|
||||
|
||||
func sanitizeCredentialsMap(input map[string]any) map[string]any {
|
||||
if input == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(input))
|
||||
for k, v := range input {
|
||||
// Avoid nil values to keep JSONB cleaner
|
||||
if v != nil {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mapCRSStatus(isActive bool, status string) string {
|
||||
if !isActive {
|
||||
return "inactive"
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(status), "error") {
|
||||
return "error"
|
||||
}
|
||||
return "active"
|
||||
}
|
||||
|
||||
func normalizeBaseURL(raw string) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("base_url is required")
|
||||
}
|
||||
u, err := url.Parse(trimmed)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return "", fmt.Errorf("invalid base_url: %s", trimmed)
|
||||
}
|
||||
u.Path = strings.TrimRight(u.Path, "/")
|
||||
return strings.TrimRight(u.String(), "/"), nil
|
||||
}
|
||||
|
||||
// cleanBaseURL removes trailing suffix from base_url in credentials
|
||||
// Used for both Claude and OpenAI accounts to remove /v1
|
||||
func cleanBaseURL(credentials map[string]any, suffixToRemove string) {
|
||||
if baseURL, ok := credentials["base_url"].(string); ok && baseURL != "" {
|
||||
trimmed := strings.TrimSpace(baseURL)
|
||||
if strings.HasSuffix(trimmed, suffixToRemove) {
|
||||
credentials["base_url"] = strings.TrimSuffix(trimmed, suffixToRemove)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func crsLogin(ctx context.Context, client *http.Client, baseURL, username, password string) (string, error) {
|
||||
payload := map[string]any{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/web/auth/login", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("crs login failed: status=%d body=%s", resp.StatusCode, string(raw))
|
||||
}
|
||||
|
||||
var parsed crsLoginResponse
|
||||
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||||
return "", fmt.Errorf("crs login parse failed: %w", err)
|
||||
}
|
||||
if !parsed.Success || strings.TrimSpace(parsed.Token) == "" {
|
||||
msg := parsed.Message
|
||||
if msg == "" {
|
||||
msg = parsed.Error
|
||||
}
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return "", errors.New("crs login failed: " + msg)
|
||||
}
|
||||
return parsed.Token, nil
|
||||
}
|
||||
|
||||
func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminToken string) (*crsExportResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/admin/sync/export-accounts?include_secrets=true", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 5<<20))
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("crs export failed: status=%d body=%s", resp.StatusCode, string(raw))
|
||||
}
|
||||
|
||||
var parsed crsExportResponse
|
||||
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("crs export parse failed: %w", err)
|
||||
}
|
||||
if !parsed.Success {
|
||||
msg := parsed.Message
|
||||
if msg == "" {
|
||||
msg = parsed.Error
|
||||
}
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return nil, errors.New("crs export failed: " + msg)
|
||||
}
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
// refreshOAuthToken attempts to refresh OAuth token for a synced account
|
||||
// Returns updated credentials or nil if refresh failed/not applicable
|
||||
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB {
|
||||
if account.Type != model.AccountTypeOAuth {
|
||||
return nil
|
||||
}
|
||||
|
||||
var newCredentials map[string]any
|
||||
var err error
|
||||
|
||||
switch account.Platform {
|
||||
case model.PlatformAnthropic:
|
||||
if s.oauthService == nil {
|
||||
return nil
|
||||
}
|
||||
tokenInfo, refreshErr := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if refreshErr != nil {
|
||||
err = refreshErr
|
||||
} else {
|
||||
// Preserve existing credentials
|
||||
newCredentials = make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
// Update token fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
}
|
||||
case model.PlatformOpenAI:
|
||||
if s.openaiOAuthService == nil {
|
||||
return nil
|
||||
}
|
||||
tokenInfo, refreshErr := s.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if refreshErr != nil {
|
||||
err = refreshErr
|
||||
} else {
|
||||
newCredentials = s.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Log but don't fail the sync - token might still be valid or refreshable later
|
||||
return nil
|
||||
}
|
||||
|
||||
return model.JSONB(newCredentials)
|
||||
}
|
||||
77
backend/internal/service/dashboard_service.go
Normal file
77
backend/internal/service/dashboard_service.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// DashboardService provides aggregated statistics for admin dashboard.
|
||||
type DashboardService struct {
|
||||
usageRepo ports.UsageLogRepository
|
||||
}
|
||||
|
||||
func NewDashboardService(usageRepo ports.UsageLogRepository) *DashboardService {
|
||||
return &DashboardService{
|
||||
usageRepo: usageRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
stats, err := s.usageRepo.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get dashboard stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key usage trend: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user usage trend: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch user usage stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
@@ -16,19 +16,16 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ClaudeUpstream handles HTTP requests to Claude API
|
||||
type ClaudeUpstream interface {
|
||||
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||
}
|
||||
|
||||
const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
@@ -87,7 +84,7 @@ type GatewayService struct {
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
claudeUpstream ClaudeUpstream
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -102,7 +99,7 @@ func NewGatewayService(
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
identityService *IdentityService,
|
||||
claudeUpstream ClaudeUpstream,
|
||||
httpUpstream ports.HTTPUpstream,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -115,7 +112,7 @@ func NewGatewayService(
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
claudeUpstream: claudeUpstream,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -285,13 +282,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号)
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulable(ctx)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -363,6 +360,25 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
|
||||
// 重试相关常量
|
||||
const (
|
||||
maxRetries = 3 // 最大重试次数
|
||||
retryDelay = 2 * time.Second // 重试等待时间
|
||||
)
|
||||
|
||||
// shouldRetryUpstreamError 判断是否应该重试上游错误
|
||||
// OAuth/Setup Token 账号:仅 403 重试
|
||||
// API Key 账号:未配置的错误码重试
|
||||
func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool {
|
||||
// OAuth/Setup Token 账号:仅 403 重试
|
||||
if account.IsOAuth() {
|
||||
return statusCode == 403
|
||||
}
|
||||
|
||||
// API Key 账号:未配置的错误码重试
|
||||
return !account.ShouldHandleErrorCode(statusCode)
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -376,6 +392,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
|
||||
if !gjson.GetBytes(body, "system").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "system", []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := req.Model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
@@ -394,26 +422,51 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建上游请求
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取代理URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要重试
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if attempt < maxRetries {
|
||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
|
||||
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
|
||||
_ = resp.Body.Close()
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
// 最后一次尝试也失败,跳出循环处理重试耗尽
|
||||
break
|
||||
}
|
||||
|
||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||
break
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||
// 处理重试耗尽的情况
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
@@ -481,7 +534,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// 设置认证头
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
@@ -502,8 +555,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// 确保必要的headers存在
|
||||
if req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
@@ -575,19 +628,6 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// apikey 类型账号:检查自定义错误码配置
|
||||
// 如果启用且错误码不在列表中,返回通用 500 错误(不做任何账号状态处理)
|
||||
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream gateway error",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 处理上游错误,标记账号状态
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
|
||||
@@ -596,6 +636,9 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
var statusCode int
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 400:
|
||||
c.Data(http.StatusBadRequest, "application/json", body)
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
case 401:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
@@ -634,6 +677,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// handleRetryExhaustedError 处理重试耗尽后的错误
|
||||
// OAuth 403:标记账号异常
|
||||
// API Key 未配置错误码:仅返回错误,不标记账号
|
||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
statusCode := resp.StatusCode
|
||||
|
||||
// OAuth/Setup Token 账号的 403:标记账号异常
|
||||
if account.IsOAuth() && statusCode == 403 {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
|
||||
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
|
||||
} else {
|
||||
// API Key 未配置错误码:不标记账号状态
|
||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
|
||||
}
|
||||
|
||||
// 返回统一的重试耗尽错误响应
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed after retries",
|
||||
},
|
||||
})
|
||||
|
||||
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", statusCode)
|
||||
}
|
||||
|
||||
// streamingResult 流式响应结果
|
||||
type streamingResult struct {
|
||||
usage *ClaudeUsage
|
||||
@@ -739,7 +810,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
||||
}
|
||||
|
||||
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
// 解析message_start获取input tokens
|
||||
// 解析message_start获取input tokens(标准Claude API格式)
|
||||
var msgStart struct {
|
||||
Type string `json:"type"`
|
||||
Message struct {
|
||||
@@ -752,15 +823,30 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
||||
}
|
||||
|
||||
// 解析message_delta获取output tokens
|
||||
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
||||
var msgDelta struct {
|
||||
Type string `json:"type"`
|
||||
Usage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
|
||||
// output_tokens 总是从 message_delta 获取
|
||||
usage.OutputTokens = msgDelta.Usage.OutputTokens
|
||||
|
||||
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
|
||||
if usage.InputTokens == 0 {
|
||||
usage.InputTokens = msgDelta.Usage.InputTokens
|
||||
}
|
||||
if usage.CacheCreationInputTokens == 0 {
|
||||
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
|
||||
}
|
||||
if usage.CacheReadInputTokens == 0 {
|
||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -982,7 +1068,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
@@ -1049,7 +1135,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
// 设置认证头
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
@@ -1073,8 +1159,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// 确保必要的 headers 存在
|
||||
if req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -114,12 +114,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerpr
|
||||
return
|
||||
}
|
||||
|
||||
// 设置User-Agent
|
||||
// 设置user-agent
|
||||
if fp.UserAgent != "" {
|
||||
req.Header.Set("User-Agent", fp.UserAgent)
|
||||
req.Header.Set("user-agent", fp.UserAgent)
|
||||
}
|
||||
|
||||
// 设置x-stainless-*头(使用正确的大小写)
|
||||
// 设置x-stainless-*头
|
||||
if fp.StainlessLang != "" {
|
||||
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||
@@ -284,3 +284,8 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// Stop stops the session store cleanup goroutine
|
||||
func (s *OAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
|
||||
836
backend/internal/service/openai_gateway_service.go
Normal file
836
backend/internal/service/openai_gateway_service.go
Normal file
@@ -0,0 +1,836 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// ChatGPT internal API for OAuth accounts
|
||||
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
// OpenAI Platform API for API Key accounts (fallback)
|
||||
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
|
||||
openaiStickySessionTTL = time.Hour // 粘性会话TTL
|
||||
)
|
||||
|
||||
// OpenAI allowed headers whitelist (for non-OAuth accounts)
|
||||
var openaiAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
"content-type": true,
|
||||
"user-agent": true,
|
||||
"originator": true,
|
||||
"session_id": true,
|
||||
}
|
||||
|
||||
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
|
||||
type OpenAICodexUsageSnapshot struct {
|
||||
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
|
||||
PrimaryResetAfterSeconds *int `json:"primary_reset_after_seconds,omitempty"`
|
||||
PrimaryWindowMinutes *int `json:"primary_window_minutes,omitempty"`
|
||||
SecondaryUsedPercent *float64 `json:"secondary_used_percent,omitempty"`
|
||||
SecondaryResetAfterSeconds *int `json:"secondary_reset_after_seconds,omitempty"`
|
||||
SecondaryWindowMinutes *int `json:"secondary_window_minutes,omitempty"`
|
||||
PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIUsage represents OpenAI API response usage
|
||||
type OpenAIUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIForwardResult represents the result of forwarding
|
||||
type OpenAIForwardResult struct {
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
}
|
||||
|
||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.GatewayCache
|
||||
cfg *config.Config
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
func NewOpenAIGatewayService(
|
||||
accountRepo ports.AccountRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.GatewayCache,
|
||||
cfg *config.Config,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
httpUpstream ports.HTTPUpstream,
|
||||
) *OpenAIGatewayService {
|
||||
return &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
|
||||
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
sessionID := c.GetHeader("session_id")
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
hash := sha256.Sum256([]byte(sessionID))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// SelectAccount selects an OpenAI account with sticky session support
|
||||
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
}
|
||||
|
||||
// SelectAccountForModel selects an account supporting the requested model
|
||||
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Select by priority + LRU
|
||||
var selected *model.Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
// Lower priority value means higher priority
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
// Same priority, select least recently used
|
||||
if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
|
||||
}
|
||||
return nil, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
// 4. Set sticky session
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// GetAccessToken gets the access token for an OpenAI account
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
case model.AccountTypeOAuth:
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
case model.AccountTypeApiKey:
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
}
|
||||
return apiKey, "apikey", nil
|
||||
default:
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward forwards request to OpenAI API
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
|
||||
// Extract model and stream from parsed body
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
// Apply model mapping
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
reqBody["model"] = mappedModel
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// For OAuth accounts using ChatGPT internal API, add store: false
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
var err error
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize request body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// Send request
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
||||
// Determine target URL based on account type
|
||||
var targetURL string
|
||||
switch account.Type {
|
||||
case model.AccountTypeOAuth:
|
||||
// OAuth accounts use ChatGPT internal API
|
||||
targetURL = chatgptCodexURL
|
||||
case model.AccountTypeApiKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
targetURL = baseURL + "/v1/responses"
|
||||
} else {
|
||||
targetURL = openaiPlatformAPIURL
|
||||
}
|
||||
default:
|
||||
targetURL = openaiPlatformAPIURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set authentication header
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
|
||||
// Set headers specific to OAuth accounts (ChatGPT internal API)
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
|
||||
req.Host = "chatgpt.com"
|
||||
// Required: set chatgpt-account-id header
|
||||
chatgptAccountID := account.GetChatGPTAccountID()
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
// Set accept header based on stream mode
|
||||
if isStream {
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("accept", "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
// Whitelist passthrough headers
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom User-Agent if configured
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
req.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// Check custom error codes
|
||||
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream gateway error",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
|
||||
// Return appropriate error response
|
||||
var errType, errMsg string
|
||||
var statusCode int
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 401:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
statusCode = http.StatusTooManyRequests
|
||||
errType = "rate_limit_error"
|
||||
errMsg = "Upstream rate limit exceeded, please retry later"
|
||||
default:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream request failed"
|
||||
}
|
||||
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": errMsg,
|
||||
},
|
||||
})
|
||||
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// openaiStreamingResult streaming response result
|
||||
type openaiStreamingResult struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
// Set SSE response headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// Pass through other headers
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace && strings.HasPrefix(line, "data: ") {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Parse usage data
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := line[6:]
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
data := line[6:]
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
// Replace model in response
|
||||
if m, ok := event["model"].(string); ok && m == fromModel {
|
||||
event["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
|
||||
// Check nested response
|
||||
if response, ok := event["response"].(map[string]any); ok {
|
||||
if m, ok := response["model"].(string); ok && m == fromModel {
|
||||
response["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
}
|
||||
|
||||
return line
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
// Parse response.completed event for usage (OpenAI Responses format)
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
|
||||
usage.InputTokens = event.Response.Usage.InputTokens
|
||||
usage.OutputTokens = event.Response.Usage.OutputTokens
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse usage
|
||||
var response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{
|
||||
InputTokens: response.Usage.InputTokens,
|
||||
OutputTokens: response.Usage.OutputTokens,
|
||||
CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
|
||||
}
|
||||
|
||||
// Replace model in response if needed
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Pass through headers
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
ApiKey *model.ApiKey
|
||||
User *model.User
|
||||
Account *model.Account
|
||||
Subscription *model.UserSubscription
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 计算实际的新输入token(减去缓存读取的token)
|
||||
// 因为 input_tokens 包含了 cache_read_tokens,而缓存读取的token不应按输入价格计费
|
||||
actualInputTokens := result.Usage.InputTokens - result.Usage.CacheReadInputTokens
|
||||
if actualInputTokens < 0 {
|
||||
actualInputTokens = 0
|
||||
}
|
||||
|
||||
// Calculate cost
|
||||
tokens := UsageTokens{
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
}
|
||||
|
||||
// Get rate multiplier
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
}
|
||||
|
||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
|
||||
// Determine billing type
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
billingType := model.BillingTypeBalance
|
||||
if isSubscriptionBilling {
|
||||
billingType = model.BillingTypeSubscription
|
||||
}
|
||||
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
if subscription != nil {
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
_ = s.usageLogRepo.Create(ctx, usageLog)
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}()
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Update account last used
|
||||
_ = s.accountRepo.UpdateLastUsed(ctx, account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractCodexUsageHeaders extracts Codex usage limits from response headers
|
||||
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
snapshot := &OpenAICodexUsageSnapshot{}
|
||||
hasData := false
|
||||
|
||||
// Helper to parse float64 from header
|
||||
parseFloat := func(key string) *float64 {
|
||||
if v := headers.Get(key); v != "" {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return &f
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper to parse int from header
|
||||
parseInt := func(key string) *int {
|
||||
if v := headers.Get(key); v != "" {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return &i
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Primary (weekly) limits
|
||||
if v := parseFloat("x-codex-primary-used-percent"); v != nil {
|
||||
snapshot.PrimaryUsedPercent = v
|
||||
hasData = true
|
||||
}
|
||||
if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil {
|
||||
snapshot.PrimaryResetAfterSeconds = v
|
||||
hasData = true
|
||||
}
|
||||
if v := parseInt("x-codex-primary-window-minutes"); v != nil {
|
||||
snapshot.PrimaryWindowMinutes = v
|
||||
hasData = true
|
||||
}
|
||||
|
||||
// Secondary (5h) limits
|
||||
if v := parseFloat("x-codex-secondary-used-percent"); v != nil {
|
||||
snapshot.SecondaryUsedPercent = v
|
||||
hasData = true
|
||||
}
|
||||
if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil {
|
||||
snapshot.SecondaryResetAfterSeconds = v
|
||||
hasData = true
|
||||
}
|
||||
if v := parseInt("x-codex-secondary-window-minutes"); v != nil {
|
||||
snapshot.SecondaryWindowMinutes = v
|
||||
hasData = true
|
||||
}
|
||||
|
||||
// Overflow ratio
|
||||
if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil {
|
||||
snapshot.PrimaryOverSecondaryPercent = v
|
||||
hasData = true
|
||||
}
|
||||
|
||||
if !hasData {
|
||||
return nil
|
||||
}
|
||||
|
||||
snapshot.UpdatedAt = time.Now().Format(time.RFC3339)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
|
||||
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
|
||||
if snapshot == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert snapshot to map for merging into Extra
|
||||
updates := make(map[string]any)
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
}
|
||||
if snapshot.PrimaryOverSecondaryPercent != nil {
|
||||
updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent
|
||||
}
|
||||
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
|
||||
|
||||
// Update account's Extra field asynchronously
|
||||
go func() {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}()
|
||||
}
|
||||
257
backend/internal/service/openai_oauth_service.go
Normal file
257
backend/internal/service/openai_oauth_service.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
proxyRepo ports.ProxyRepository
|
||||
oauthClient ports.OpenAIOAuthClient
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthService creates a new OpenAI OAuth service
|
||||
func NewOpenAIOAuthService(proxyRepo ports.ProxyRepository, oauthClient ports.OpenAIOAuthClient) *OpenAIOAuthService {
|
||||
return &OpenAIOAuthService{
|
||||
sessionStore: openai.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
oauthClient: oauthClient,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIAuthURLResult contains the authorization URL and session info
|
||||
type OpenAIAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := openai.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate session ID
|
||||
sessionID, err := openai.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Use default redirect URI if not specified
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
|
||||
// Store session
|
||||
session := &openai.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
|
||||
|
||||
return &OpenAIAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OpenAIExchangeCodeInput represents the input for code exchange
|
||||
type OpenAIExchangeCodeInput struct {
|
||||
SessionID string
|
||||
Code string
|
||||
RedirectURI string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// OpenAITokenInfo represents the token information for OpenAI
|
||||
type OpenAITokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
OrganizationID string `json:"organization_id,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExchangeCodeInput) (*OpenAITokenInfo, error) {
|
||||
// Get session
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Use redirect URI from session or input
|
||||
redirectURI := session.RedirectURI
|
||||
if input.RedirectURI != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
|
||||
// Delete session after successful exchange
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, fmt.Errorf("account is not an OpenAI account")
|
||||
}
|
||||
|
||||
refreshToken := account.GetOpenAIRefreshToken()
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials builds credentials map from token info
|
||||
func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) map[string]any {
|
||||
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
|
||||
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"expires_at": expiresAt,
|
||||
}
|
||||
|
||||
if tokenInfo.IDToken != "" {
|
||||
creds["id_token"] = tokenInfo.IDToken
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.ChatGPTAccountID != "" {
|
||||
creds["chatgpt_account_id"] = tokenInfo.ChatGPTAccountID
|
||||
}
|
||||
if tokenInfo.ChatGPTUserID != "" {
|
||||
creds["chatgpt_user_id"] = tokenInfo.ChatGPTUserID
|
||||
}
|
||||
if tokenInfo.OrganizationID != "" {
|
||||
creds["organization_id"] = tokenInfo.OrganizationID
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
|
||||
// Stop stops the session store cleanup goroutine
|
||||
func (s *OpenAIOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
@@ -4,13 +4,16 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *model.Account) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Account, error)
|
||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||
// Returns (nil, nil) if not found.
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
|
||||
Update(ctx context.Context, account *model.Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
@@ -27,9 +30,25 @@ type AccountRepository interface {
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]model.Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
}
|
||||
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
// Nil pointers mean "do not change".
|
||||
type AccountBulkUpdate struct {
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status *string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package ports
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
|
||||
@@ -3,17 +3,21 @@ package ports
|
||||
import "context"
|
||||
|
||||
// ConcurrencyCache defines cache operations for concurrency service
|
||||
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
|
||||
type ConcurrencyCache interface {
|
||||
// Slot management
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64) error
|
||||
// Account slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:account:{accountID}:{requestID}
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64) error
|
||||
// User slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:user:{userID}:{requestID}
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// Wait queue
|
||||
// Wait queue - uses counter with TTL set only on creation
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user