mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 23:42:13 +08:00
Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a9c17b9d1 | ||
|
|
799b010631 | ||
|
|
2d83941aaa | ||
|
|
f6a9a0a45a | ||
|
|
5b8d4fb047 | ||
|
|
afcfbb458d | ||
|
|
8f24d239af | ||
|
|
b7a29a4bac | ||
|
|
a42105881f | ||
|
|
958ffe7a8a | ||
|
|
b46b3c5c3c | ||
|
|
fd1b14fd1d | ||
|
|
eb198e5969 | ||
|
|
70fcbd7006 | ||
|
|
b015a3bd8a | ||
|
|
3fb43b91bf | ||
|
|
6e8188ed64 | ||
|
|
db6f53e2c9 | ||
|
|
acabdc2f99 | ||
|
|
169aa4716e | ||
|
|
c0753320a0 | ||
|
|
38d875b06f | ||
|
|
1ada6cf768 | ||
|
|
2b528c5f81 | ||
|
|
f6dd4752e7 | ||
|
|
b19c7875a4 | ||
|
|
d99a3ef14b | ||
|
|
fc8fa83fcc | ||
|
|
6dcd99468b | ||
|
|
d5ba7b80d3 | ||
|
|
a3b81ef7bc | ||
|
|
015974a27e | ||
|
|
4cf756ebe6 | ||
|
|
823497a2af | ||
|
|
66fe484f0d | ||
|
|
216321aa9e | ||
|
|
5a52cb608c | ||
|
|
1181b332f7 | ||
|
|
58b1777198 | ||
|
|
0c7a58fcc7 | ||
|
|
17ae51c0a0 | ||
|
|
4790aced15 | ||
|
|
3f0017d1f1 | ||
|
|
cb72262ad8 | ||
|
|
195e227c04 | ||
|
|
f5603b0780 | ||
|
|
752882a022 | ||
|
|
9731b961d0 | ||
|
|
2920409404 | ||
|
|
7dbbfc22b6 | ||
|
|
aaaa68ea7f | ||
|
|
af753de481 | ||
|
|
3956819c78 | ||
|
|
6fa704d6fc | ||
|
|
0400fcdca4 | ||
|
|
5b1907fe61 | ||
|
|
d4c2b723a5 |
16
.github/audit-exceptions.yml
vendored
Normal file
16
.github/audit-exceptions.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
version: 1
|
||||
exceptions:
|
||||
- package: xlsx
|
||||
advisory: "GHSA-4r6h-8v6p-xvw6"
|
||||
severity: high
|
||||
reason: "Admin export only; switched to dynamic import to reduce exposure (CVE-2023-30533)"
|
||||
mitigation: "Load only on export; restrict export permissions and data scope"
|
||||
expires_on: "2026-04-05"
|
||||
owner: "security@your-domain"
|
||||
- package: xlsx
|
||||
advisory: "GHSA-5pgg-2g8v-p4x9"
|
||||
severity: high
|
||||
reason: "Admin export only; switched to dynamic import to reduce exposure (CVE-2024-22363)"
|
||||
mitigation: "Load only on export; restrict export permissions and data scope"
|
||||
expires_on: "2026-04-05"
|
||||
owner: "security@your-domain"
|
||||
10
.github/workflows/backend-ci.yml
vendored
10
.github/workflows/backend-ci.yml
vendored
@@ -15,8 +15,11 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
check-latest: false
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.5'
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
@@ -31,8 +34,11 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: true
|
||||
check-latest: false
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.5'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
|
||||
7
.github/workflows/release.yml
vendored
7
.github/workflows/release.yml
vendored
@@ -109,9 +109,14 @@ jobs:
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.24'
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache-dependency-path: backend/go.sum
|
||||
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.5'
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
62
.github/workflows/security-scan.yml
vendored
Normal file
62
.github/workflows/security-scan.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
name: Security Scan
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
schedule:
|
||||
- cron: '0 3 * * 1'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
backend-security:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.5'
|
||||
- name: Run govulncheck
|
||||
working-directory: backend
|
||||
run: |
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
govulncheck ./...
|
||||
- name: Run gosec
|
||||
working-directory: backend
|
||||
run: |
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
gosec -severity high -confidence high ./...
|
||||
|
||||
frontend-security:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 9
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'pnpm'
|
||||
cache-dependency-path: frontend/pnpm-lock.yaml
|
||||
- name: Install dependencies
|
||||
working-directory: frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
- name: Run pnpm audit
|
||||
working-directory: frontend
|
||||
run: |
|
||||
pnpm audit --prod --audit-level=high --json > audit.json || true
|
||||
- name: Check audit exceptions
|
||||
run: |
|
||||
python tools/check_pnpm_audit_exceptions.py \
|
||||
--audit frontend/audit.json \
|
||||
--exceptions .github/audit-exceptions.yml
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -123,3 +123,6 @@ backend/cmd/server/server
|
||||
deploy/docker-compose.override.yml
|
||||
.gocache/
|
||||
vite.config.js
|
||||
!docs/
|
||||
docs/*
|
||||
!docs/dependency-security.md
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.25-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.19
|
||||
ARG GOLANG_IMAGE=golang:1.25.5-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.20
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
|
||||
6
Makefile
6
Makefile
@@ -9,7 +9,7 @@ build-backend:
|
||||
|
||||
# 编译前端(需要已安装依赖)
|
||||
build-frontend:
|
||||
@npm --prefix frontend run build
|
||||
@pnpm --dir frontend run build
|
||||
|
||||
# 运行测试(后端 + 前端)
|
||||
test: test-backend test-frontend
|
||||
@@ -18,5 +18,5 @@ test-backend:
|
||||
@$(MAKE) -C backend test
|
||||
|
||||
test-frontend:
|
||||
@npm --prefix frontend run lint:check
|
||||
@npm --prefix frontend run typecheck
|
||||
@pnpm --dir frontend run lint:check
|
||||
@pnpm --dir frontend run typecheck
|
||||
|
||||
63
README.md
63
README.md
@@ -2,7 +2,7 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://golang.org/)
|
||||
[](https://golang.org/)
|
||||
[](https://vuejs.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://redis.io/)
|
||||
@@ -44,13 +44,19 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Backend | Go 1.21+, Gin, GORM |
|
||||
| Backend | Go 1.25.5, Gin, Ent |
|
||||
| Frontend | Vue 3.4+, Vite 5+, TailwindCSS |
|
||||
| Database | PostgreSQL 15+ |
|
||||
| Cache/Queue | Redis 7+ |
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
|
||||
- Dependency Security: `docs/dependency-security.md`
|
||||
|
||||
---
|
||||
|
||||
## Deployment
|
||||
|
||||
### Method 1: Script Installation (Recommended)
|
||||
@@ -160,6 +166,22 @@ ADMIN_PASSWORD=your_admin_password
|
||||
|
||||
# Optional: Custom port
|
||||
SERVER_PORT=8080
|
||||
|
||||
# Optional: Security configuration
|
||||
# Enable URL allowlist validation (false to skip allowlist checks, only basic format validation)
|
||||
SECURITY_URL_ALLOWLIST_ENABLED=false
|
||||
|
||||
# Allow insecure HTTP URLs when allowlist is disabled (default: false, requires https)
|
||||
# ⚠️ WARNING: Enabling this allows HTTP (plaintext) URLs which can expose API keys
|
||||
# Only recommended for:
|
||||
# - Development/testing environments
|
||||
# - Internal networks with trusted endpoints
|
||||
# - When using local test servers (http://localhost)
|
||||
# PRODUCTION: Keep this false or use HTTPS URLs only
|
||||
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=false
|
||||
|
||||
# Allow private IP addresses for upstream/pricing/CRS (for internal deployments)
|
||||
SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=false
|
||||
```
|
||||
|
||||
```bash
|
||||
@@ -276,13 +298,48 @@ Additional security-related options are available in `config.yaml`:
|
||||
- `cors.allowed_origins` for CORS allowlist
|
||||
- `security.url_allowlist` for upstream/pricing/CRS host allowlists
|
||||
- `security.url_allowlist.enabled` to disable URL validation (use with caution)
|
||||
- `security.url_allowlist.allow_insecure_http` to allow http URLs when validation is disabled
|
||||
- `security.url_allowlist.allow_insecure_http` to allow HTTP URLs when validation is disabled
|
||||
- `security.url_allowlist.allow_private_hosts` to allow private/local IP addresses
|
||||
- `security.response_headers.enabled` to enable configurable response header filtering (disabled uses default allowlist)
|
||||
- `security.csp` to control Content-Security-Policy headers
|
||||
- `billing.circuit_breaker` to fail closed on billing errors
|
||||
- `server.trusted_proxies` to enable X-Forwarded-For parsing
|
||||
- `turnstile.required` to require Turnstile in release mode
|
||||
|
||||
**⚠️ Security Warning: HTTP URL Configuration**
|
||||
|
||||
When `security.url_allowlist.enabled=false`, the system performs minimal URL validation by default, **rejecting HTTP URLs** and only allowing HTTPS. To allow HTTP URLs (e.g., for development or internal testing), you must explicitly set:
|
||||
|
||||
```yaml
|
||||
security:
|
||||
url_allowlist:
|
||||
enabled: false # Disable allowlist checks
|
||||
allow_insecure_http: true # Allow HTTP URLs (⚠️ INSECURE)
|
||||
```
|
||||
|
||||
**Or via environment variable:**
|
||||
|
||||
```bash
|
||||
SECURITY_URL_ALLOWLIST_ENABLED=false
|
||||
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=true
|
||||
```
|
||||
|
||||
**Risks of allowing HTTP:**
|
||||
- API keys and data transmitted in **plaintext** (vulnerable to interception)
|
||||
- Susceptible to **man-in-the-middle (MITM) attacks**
|
||||
- **NOT suitable for production** environments
|
||||
|
||||
**When to use HTTP:**
|
||||
- ✅ Development/testing with local servers (http://localhost)
|
||||
- ✅ Internal networks with trusted endpoints
|
||||
- ✅ Testing account connectivity before obtaining HTTPS
|
||||
- ❌ Production environments (use HTTPS only)
|
||||
|
||||
**Example error without this setting:**
|
||||
```
|
||||
Invalid base URL: invalid url scheme: http
|
||||
```
|
||||
|
||||
If you disable URL validation or response header filtering, harden your network layer:
|
||||
- Enforce an egress allowlist for upstream domains/IPs
|
||||
- Block private/loopback/link-local ranges
|
||||
|
||||
63
README_CN.md
63
README_CN.md
@@ -2,7 +2,7 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://golang.org/)
|
||||
[](https://golang.org/)
|
||||
[](https://vuejs.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://redis.io/)
|
||||
@@ -44,13 +44,19 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
|
||||
| 组件 | 技术 |
|
||||
|------|------|
|
||||
| 后端 | Go 1.21+, Gin, GORM |
|
||||
| 后端 | Go 1.25.5, Gin, Ent |
|
||||
| 前端 | Vue 3.4+, Vite 5+, TailwindCSS |
|
||||
| 数据库 | PostgreSQL 15+ |
|
||||
| 缓存/队列 | Redis 7+ |
|
||||
|
||||
---
|
||||
|
||||
## 文档
|
||||
|
||||
- 依赖安全:`docs/dependency-security.md`
|
||||
|
||||
---
|
||||
|
||||
## 部署方式
|
||||
|
||||
### 方式一:脚本安装(推荐)
|
||||
@@ -160,6 +166,22 @@ ADMIN_PASSWORD=your_admin_password
|
||||
|
||||
# 可选:自定义端口
|
||||
SERVER_PORT=8080
|
||||
|
||||
# 可选:安全配置
|
||||
# 启用 URL 白名单验证(false 则跳过白名单检查,仅做基本格式校验)
|
||||
SECURITY_URL_ALLOWLIST_ENABLED=false
|
||||
|
||||
# 关闭白名单时,是否允许 http:// URL(默认 false,只允许 https://)
|
||||
# ⚠️ 警告:允许 HTTP 会暴露 API 密钥(明文传输)
|
||||
# 仅建议在以下场景使用:
|
||||
# - 开发/测试环境
|
||||
# - 内部可信网络
|
||||
# - 本地测试服务器(http://localhost)
|
||||
# 生产环境:保持 false 或仅使用 HTTPS URL
|
||||
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=false
|
||||
|
||||
# 是否允许私有 IP 地址用于上游/定价/CRS(内网部署时使用)
|
||||
SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=false
|
||||
```
|
||||
|
||||
```bash
|
||||
@@ -276,13 +298,48 @@ default:
|
||||
- `cors.allowed_origins` 配置 CORS 白名单
|
||||
- `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单
|
||||
- `security.url_allowlist.enabled` 可关闭 URL 校验(慎用)
|
||||
- `security.url_allowlist.allow_insecure_http` 关闭校验时允许 http URL
|
||||
- `security.url_allowlist.allow_insecure_http` 关闭校验时允许 HTTP URL
|
||||
- `security.url_allowlist.allow_private_hosts` 允许私有/本地 IP 地址
|
||||
- `security.response_headers.enabled` 可启用可配置响应头过滤(关闭时使用默认白名单)
|
||||
- `security.csp` 配置 Content-Security-Policy
|
||||
- `billing.circuit_breaker` 计费异常时 fail-closed
|
||||
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
|
||||
- `turnstile.required` 在 release 模式强制启用 Turnstile
|
||||
|
||||
**⚠️ 安全警告:HTTP URL 配置**
|
||||
|
||||
当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置:
|
||||
|
||||
```yaml
|
||||
security:
|
||||
url_allowlist:
|
||||
enabled: false # 禁用白名单检查
|
||||
allow_insecure_http: true # 允许 HTTP URL(⚠️ 不安全)
|
||||
```
|
||||
|
||||
**或通过环境变量:**
|
||||
|
||||
```bash
|
||||
SECURITY_URL_ALLOWLIST_ENABLED=false
|
||||
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=true
|
||||
```
|
||||
|
||||
**允许 HTTP 的风险:**
|
||||
- API 密钥和数据以**明文传输**(可被截获)
|
||||
- 易受**中间人攻击 (MITM)**
|
||||
- **不适合生产环境**
|
||||
|
||||
**适用场景:**
|
||||
- ✅ 开发/测试环境的本地服务器(http://localhost)
|
||||
- ✅ 内网可信端点
|
||||
- ✅ 获取 HTTPS 前测试账号连通性
|
||||
- ❌ 生产环境(仅使用 HTTPS)
|
||||
|
||||
**未设置此项时的错误示例:**
|
||||
```
|
||||
Invalid base URL: invalid url scheme: http
|
||||
```
|
||||
|
||||
如关闭 URL 校验或响应头过滤,请加强网络层防护:
|
||||
- 出站访问白名单限制上游域名/IP
|
||||
- 阻断私网/回环/链路本地地址
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.21-alpine
|
||||
FROM golang:1.25.5-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ func provideCleanup(
|
||||
entClient *ent.Client,
|
||||
rdb *redis.Client,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
billingCache *service.BillingCacheService,
|
||||
@@ -84,6 +85,10 @@ func provideCleanup(
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"AccountExpiryService", func() error {
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -87,6 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
|
||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache)
|
||||
@@ -97,15 +98,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -114,7 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
gitHubReleaseClient := repository.NewGitHubReleaseClient()
|
||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
|
||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||
@@ -125,7 +125,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient(configConfig)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -148,7 +148,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
||||
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
v := provideCleanup(client, redisClient, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -174,6 +175,7 @@ func provideCleanup(
|
||||
entClient *ent.Client,
|
||||
rdb *redis.Client,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
billingCache *service.BillingCacheService,
|
||||
@@ -194,6 +196,10 @@ func provideCleanup(
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"AccountExpiryService", func() error {
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -49,6 +49,10 @@ type Account struct {
|
||||
ErrorMessage *string `json:"error_message,omitempty"`
|
||||
// LastUsedAt holds the value of the "last_used_at" field.
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
// Account expiration time (NULL means no expiration).
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
// Auto pause scheduling when account expires.
|
||||
AutoPauseOnExpired bool `json:"auto_pause_on_expired,omitempty"`
|
||||
// Schedulable holds the value of the "schedulable" field.
|
||||
Schedulable bool `json:"schedulable,omitempty"`
|
||||
// RateLimitedAt holds the value of the "rate_limited_at" field.
|
||||
@@ -129,13 +133,13 @@ func (*Account) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case account.FieldCredentials, account.FieldExtra:
|
||||
values[i] = new([]byte)
|
||||
case account.FieldSchedulable:
|
||||
case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
|
||||
values[i] = new(sql.NullBool)
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
|
||||
values[i] = new(sql.NullString)
|
||||
case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd:
|
||||
case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
@@ -257,6 +261,19 @@ func (_m *Account) assignValues(columns []string, values []any) error {
|
||||
_m.LastUsedAt = new(time.Time)
|
||||
*_m.LastUsedAt = value.Time
|
||||
}
|
||||
case account.FieldExpiresAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ExpiresAt = new(time.Time)
|
||||
*_m.ExpiresAt = value.Time
|
||||
}
|
||||
case account.FieldAutoPauseOnExpired:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field auto_pause_on_expired", values[i])
|
||||
} else if value.Valid {
|
||||
_m.AutoPauseOnExpired = value.Bool
|
||||
}
|
||||
case account.FieldSchedulable:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field schedulable", values[i])
|
||||
@@ -416,6 +433,14 @@ func (_m *Account) String() string {
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ExpiresAt; v != nil {
|
||||
builder.WriteString("expires_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("auto_pause_on_expired=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.AutoPauseOnExpired))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("schedulable=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Schedulable))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -45,6 +45,10 @@ const (
|
||||
FieldErrorMessage = "error_message"
|
||||
// FieldLastUsedAt holds the string denoting the last_used_at field in the database.
|
||||
FieldLastUsedAt = "last_used_at"
|
||||
// FieldExpiresAt holds the string denoting the expires_at field in the database.
|
||||
FieldExpiresAt = "expires_at"
|
||||
// FieldAutoPauseOnExpired holds the string denoting the auto_pause_on_expired field in the database.
|
||||
FieldAutoPauseOnExpired = "auto_pause_on_expired"
|
||||
// FieldSchedulable holds the string denoting the schedulable field in the database.
|
||||
FieldSchedulable = "schedulable"
|
||||
// FieldRateLimitedAt holds the string denoting the rate_limited_at field in the database.
|
||||
@@ -115,6 +119,8 @@ var Columns = []string{
|
||||
FieldStatus,
|
||||
FieldErrorMessage,
|
||||
FieldLastUsedAt,
|
||||
FieldExpiresAt,
|
||||
FieldAutoPauseOnExpired,
|
||||
FieldSchedulable,
|
||||
FieldRateLimitedAt,
|
||||
FieldRateLimitResetAt,
|
||||
@@ -172,6 +178,8 @@ var (
|
||||
DefaultStatus string
|
||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
StatusValidator func(string) error
|
||||
// DefaultAutoPauseOnExpired holds the default value on creation for the "auto_pause_on_expired" field.
|
||||
DefaultAutoPauseOnExpired bool
|
||||
// DefaultSchedulable holds the default value on creation for the "schedulable" field.
|
||||
DefaultSchedulable bool
|
||||
// SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||
@@ -251,6 +259,16 @@ func ByLastUsedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldLastUsedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByExpiresAt orders the results by the expires_at field.
|
||||
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAutoPauseOnExpired orders the results by the auto_pause_on_expired field.
|
||||
func ByAutoPauseOnExpired(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAutoPauseOnExpired, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySchedulable orders the results by the schedulable field.
|
||||
func BySchedulable(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSchedulable, opts...).ToFunc()
|
||||
|
||||
@@ -120,6 +120,16 @@ func LastUsedAt(v time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldLastUsedAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
|
||||
func ExpiresAt(v time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// AutoPauseOnExpired applies equality check predicate on the "auto_pause_on_expired" field. It's identical to AutoPauseOnExpiredEQ.
|
||||
func AutoPauseOnExpired(v bool) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldAutoPauseOnExpired, v))
|
||||
}
|
||||
|
||||
// Schedulable applies equality check predicate on the "schedulable" field. It's identical to SchedulableEQ.
|
||||
func Schedulable(v bool) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldSchedulable, v))
|
||||
@@ -855,6 +865,66 @@ func LastUsedAtNotNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldNotNull(FieldLastUsedAt))
|
||||
}
|
||||
|
||||
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
|
||||
func ExpiresAtEQ(v time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
|
||||
func ExpiresAtNEQ(v time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldNEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtIn applies the In predicate on the "expires_at" field.
|
||||
func ExpiresAtIn(vs ...time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldIn(FieldExpiresAt, vs...))
|
||||
}
|
||||
|
||||
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
|
||||
func ExpiresAtNotIn(vs ...time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldNotIn(FieldExpiresAt, vs...))
|
||||
}
|
||||
|
||||
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
|
||||
func ExpiresAtGT(v time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldGT(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
|
||||
func ExpiresAtGTE(v time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldGTE(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
|
||||
func ExpiresAtLT(v time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldLT(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
|
||||
func ExpiresAtLTE(v time.Time) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
|
||||
func ExpiresAtIsNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldIsNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
|
||||
func ExpiresAtNotNil() predicate.Account {
|
||||
return predicate.Account(sql.FieldNotNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// AutoPauseOnExpiredEQ applies the EQ predicate on the "auto_pause_on_expired" field.
|
||||
func AutoPauseOnExpiredEQ(v bool) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldAutoPauseOnExpired, v))
|
||||
}
|
||||
|
||||
// AutoPauseOnExpiredNEQ applies the NEQ predicate on the "auto_pause_on_expired" field.
|
||||
func AutoPauseOnExpiredNEQ(v bool) predicate.Account {
|
||||
return predicate.Account(sql.FieldNEQ(FieldAutoPauseOnExpired, v))
|
||||
}
|
||||
|
||||
// SchedulableEQ applies the EQ predicate on the "schedulable" field.
|
||||
func SchedulableEQ(v bool) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldSchedulable, v))
|
||||
|
||||
@@ -195,6 +195,34 @@ func (_c *AccountCreate) SetNillableLastUsedAt(v *time.Time) *AccountCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_c *AccountCreate) SetExpiresAt(v time.Time) *AccountCreate {
|
||||
_c.mutation.SetExpiresAt(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_c *AccountCreate) SetNillableExpiresAt(v *time.Time) *AccountCreate {
|
||||
if v != nil {
|
||||
_c.SetExpiresAt(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field.
|
||||
func (_c *AccountCreate) SetAutoPauseOnExpired(v bool) *AccountCreate {
|
||||
_c.mutation.SetAutoPauseOnExpired(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAutoPauseOnExpired sets the "auto_pause_on_expired" field if the given value is not nil.
|
||||
func (_c *AccountCreate) SetNillableAutoPauseOnExpired(v *bool) *AccountCreate {
|
||||
if v != nil {
|
||||
_c.SetAutoPauseOnExpired(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSchedulable sets the "schedulable" field.
|
||||
func (_c *AccountCreate) SetSchedulable(v bool) *AccountCreate {
|
||||
_c.mutation.SetSchedulable(v)
|
||||
@@ -405,6 +433,10 @@ func (_c *AccountCreate) defaults() error {
|
||||
v := account.DefaultStatus
|
||||
_c.mutation.SetStatus(v)
|
||||
}
|
||||
if _, ok := _c.mutation.AutoPauseOnExpired(); !ok {
|
||||
v := account.DefaultAutoPauseOnExpired
|
||||
_c.mutation.SetAutoPauseOnExpired(v)
|
||||
}
|
||||
if _, ok := _c.mutation.Schedulable(); !ok {
|
||||
v := account.DefaultSchedulable
|
||||
_c.mutation.SetSchedulable(v)
|
||||
@@ -464,6 +496,9 @@ func (_c *AccountCreate) check() error {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Account.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.AutoPauseOnExpired(); !ok {
|
||||
return &ValidationError{Name: "auto_pause_on_expired", err: errors.New(`ent: missing required field "Account.auto_pause_on_expired"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.Schedulable(); !ok {
|
||||
return &ValidationError{Name: "schedulable", err: errors.New(`ent: missing required field "Account.schedulable"`)}
|
||||
}
|
||||
@@ -555,6 +590,14 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(account.FieldLastUsedAt, field.TypeTime, value)
|
||||
_node.LastUsedAt = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(account.FieldExpiresAt, field.TypeTime, value)
|
||||
_node.ExpiresAt = &value
|
||||
}
|
||||
if value, ok := _c.mutation.AutoPauseOnExpired(); ok {
|
||||
_spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value)
|
||||
_node.AutoPauseOnExpired = value
|
||||
}
|
||||
if value, ok := _c.mutation.Schedulable(); ok {
|
||||
_spec.SetField(account.FieldSchedulable, field.TypeBool, value)
|
||||
_node.Schedulable = value
|
||||
@@ -898,6 +941,36 @@ func (u *AccountUpsert) ClearLastUsedAt() *AccountUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *AccountUpsert) SetExpiresAt(v time.Time) *AccountUpsert {
|
||||
u.Set(account.FieldExpiresAt, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *AccountUpsert) UpdateExpiresAt() *AccountUpsert {
|
||||
u.SetExcluded(account.FieldExpiresAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *AccountUpsert) ClearExpiresAt() *AccountUpsert {
|
||||
u.SetNull(account.FieldExpiresAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field.
|
||||
func (u *AccountUpsert) SetAutoPauseOnExpired(v bool) *AccountUpsert {
|
||||
u.Set(account.FieldAutoPauseOnExpired, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAutoPauseOnExpired sets the "auto_pause_on_expired" field to the value that was provided on create.
|
||||
func (u *AccountUpsert) UpdateAutoPauseOnExpired() *AccountUpsert {
|
||||
u.SetExcluded(account.FieldAutoPauseOnExpired)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSchedulable sets the "schedulable" field.
|
||||
func (u *AccountUpsert) SetSchedulable(v bool) *AccountUpsert {
|
||||
u.Set(account.FieldSchedulable, v)
|
||||
@@ -1308,6 +1381,41 @@ func (u *AccountUpsertOne) ClearLastUsedAt() *AccountUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *AccountUpsertOne) SetExpiresAt(v time.Time) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetExpiresAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *AccountUpsertOne) UpdateExpiresAt() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *AccountUpsertOne) ClearExpiresAt() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field.
|
||||
func (u *AccountUpsertOne) SetAutoPauseOnExpired(v bool) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetAutoPauseOnExpired(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAutoPauseOnExpired sets the "auto_pause_on_expired" field to the value that was provided on create.
|
||||
func (u *AccountUpsertOne) UpdateAutoPauseOnExpired() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateAutoPauseOnExpired()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSchedulable sets the "schedulable" field.
|
||||
func (u *AccountUpsertOne) SetSchedulable(v bool) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
@@ -1904,6 +2012,41 @@ func (u *AccountUpsertBulk) ClearLastUsedAt() *AccountUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *AccountUpsertBulk) SetExpiresAt(v time.Time) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetExpiresAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *AccountUpsertBulk) UpdateExpiresAt() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *AccountUpsertBulk) ClearExpiresAt() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.ClearExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field.
|
||||
func (u *AccountUpsertBulk) SetAutoPauseOnExpired(v bool) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetAutoPauseOnExpired(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAutoPauseOnExpired sets the "auto_pause_on_expired" field to the value that was provided on create.
|
||||
func (u *AccountUpsertBulk) UpdateAutoPauseOnExpired() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateAutoPauseOnExpired()
|
||||
})
|
||||
}
|
||||
|
||||
// SetSchedulable sets the "schedulable" field.
|
||||
func (u *AccountUpsertBulk) SetSchedulable(v bool) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
|
||||
@@ -247,6 +247,40 @@ func (_u *AccountUpdate) ClearLastUsedAt() *AccountUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_u *AccountUpdate) SetExpiresAt(v time.Time) *AccountUpdate {
|
||||
_u.mutation.SetExpiresAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_u *AccountUpdate) SetNillableExpiresAt(v *time.Time) *AccountUpdate {
|
||||
if v != nil {
|
||||
_u.SetExpiresAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (_u *AccountUpdate) ClearExpiresAt() *AccountUpdate {
|
||||
_u.mutation.ClearExpiresAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field.
|
||||
func (_u *AccountUpdate) SetAutoPauseOnExpired(v bool) *AccountUpdate {
|
||||
_u.mutation.SetAutoPauseOnExpired(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAutoPauseOnExpired sets the "auto_pause_on_expired" field if the given value is not nil.
|
||||
func (_u *AccountUpdate) SetNillableAutoPauseOnExpired(v *bool) *AccountUpdate {
|
||||
if v != nil {
|
||||
_u.SetAutoPauseOnExpired(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSchedulable sets the "schedulable" field.
|
||||
func (_u *AccountUpdate) SetSchedulable(v bool) *AccountUpdate {
|
||||
_u.mutation.SetSchedulable(v)
|
||||
@@ -610,6 +644,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.LastUsedAtCleared() {
|
||||
_spec.ClearField(account.FieldLastUsedAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(account.FieldExpiresAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(account.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.AutoPauseOnExpired(); ok {
|
||||
_spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Schedulable(); ok {
|
||||
_spec.SetField(account.FieldSchedulable, field.TypeBool, value)
|
||||
}
|
||||
@@ -1016,6 +1059,40 @@ func (_u *AccountUpdateOne) ClearLastUsedAt() *AccountUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_u *AccountUpdateOne) SetExpiresAt(v time.Time) *AccountUpdateOne {
|
||||
_u.mutation.SetExpiresAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_u *AccountUpdateOne) SetNillableExpiresAt(v *time.Time) *AccountUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetExpiresAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (_u *AccountUpdateOne) ClearExpiresAt() *AccountUpdateOne {
|
||||
_u.mutation.ClearExpiresAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAutoPauseOnExpired sets the "auto_pause_on_expired" field.
|
||||
func (_u *AccountUpdateOne) SetAutoPauseOnExpired(v bool) *AccountUpdateOne {
|
||||
_u.mutation.SetAutoPauseOnExpired(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAutoPauseOnExpired sets the "auto_pause_on_expired" field if the given value is not nil.
|
||||
func (_u *AccountUpdateOne) SetNillableAutoPauseOnExpired(v *bool) *AccountUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAutoPauseOnExpired(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSchedulable sets the "schedulable" field.
|
||||
func (_u *AccountUpdateOne) SetSchedulable(v bool) *AccountUpdateOne {
|
||||
_u.mutation.SetSchedulable(v)
|
||||
@@ -1409,6 +1486,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
||||
if _u.mutation.LastUsedAtCleared() {
|
||||
_spec.ClearField(account.FieldLastUsedAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(account.FieldExpiresAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(account.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.AutoPauseOnExpired(); ok {
|
||||
_spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Schedulable(); ok {
|
||||
_spec.SetField(account.FieldSchedulable, field.TypeBool, value)
|
||||
}
|
||||
|
||||
@@ -45,6 +45,16 @@ type Group struct {
|
||||
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
// DefaultValidityDays holds the value of the "default_validity_days" field.
|
||||
DefaultValidityDays int `json:"default_validity_days,omitempty"`
|
||||
// ImagePrice1k holds the value of the "image_price_1k" field.
|
||||
ImagePrice1k *float64 `json:"image_price_1k,omitempty"`
|
||||
// ImagePrice2k holds the value of the "image_price_2k" field.
|
||||
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
|
||||
// ImagePrice4k holds the value of the "image_price_4k" field.
|
||||
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
|
||||
// 是否仅允许 Claude Code 客户端
|
||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||
// 非 Claude Code 请求降级使用的分组 ID
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -151,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case group.FieldIsExclusive:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd:
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays:
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -271,6 +281,40 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.DefaultValidityDays = int(value.Int64)
|
||||
}
|
||||
case group.FieldImagePrice1k:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_price_1k", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImagePrice1k = new(float64)
|
||||
*_m.ImagePrice1k = value.Float64
|
||||
}
|
||||
case group.FieldImagePrice2k:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_price_2k", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImagePrice2k = new(float64)
|
||||
*_m.ImagePrice2k = value.Float64
|
||||
}
|
||||
case group.FieldImagePrice4k:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_price_4k", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImagePrice4k = new(float64)
|
||||
*_m.ImagePrice4k = value.Float64
|
||||
}
|
||||
case group.FieldClaudeCodeOnly:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ClaudeCodeOnly = value.Bool
|
||||
}
|
||||
case group.FieldFallbackGroupID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field fallback_group_id", values[i])
|
||||
} else if value.Valid {
|
||||
_m.FallbackGroupID = new(int64)
|
||||
*_m.FallbackGroupID = value.Int64
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -398,6 +442,29 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("default_validity_days=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImagePrice1k; v != nil {
|
||||
builder.WriteString("image_price_1k=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImagePrice2k; v != nil {
|
||||
builder.WriteString("image_price_2k=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImagePrice4k; v != nil {
|
||||
builder.WriteString("image_price_4k=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("claude_code_only=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.FallbackGroupID; v != nil {
|
||||
builder.WriteString("fallback_group_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -43,6 +43,16 @@ const (
|
||||
FieldMonthlyLimitUsd = "monthly_limit_usd"
|
||||
// FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database.
|
||||
FieldDefaultValidityDays = "default_validity_days"
|
||||
// FieldImagePrice1k holds the string denoting the image_price_1k field in the database.
|
||||
FieldImagePrice1k = "image_price_1k"
|
||||
// FieldImagePrice2k holds the string denoting the image_price_2k field in the database.
|
||||
FieldImagePrice2k = "image_price_2k"
|
||||
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
|
||||
FieldImagePrice4k = "image_price_4k"
|
||||
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
|
||||
FieldClaudeCodeOnly = "claude_code_only"
|
||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||
FieldFallbackGroupID = "fallback_group_id"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -132,6 +142,11 @@ var Columns = []string{
|
||||
FieldWeeklyLimitUsd,
|
||||
FieldMonthlyLimitUsd,
|
||||
FieldDefaultValidityDays,
|
||||
FieldImagePrice1k,
|
||||
FieldImagePrice2k,
|
||||
FieldImagePrice4k,
|
||||
FieldClaudeCodeOnly,
|
||||
FieldFallbackGroupID,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -187,6 +202,8 @@ var (
|
||||
SubscriptionTypeValidator func(string) error
|
||||
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
||||
DefaultDefaultValidityDays int
|
||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||
DefaultClaudeCodeOnly bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -267,6 +284,31 @@ func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImagePrice1k orders the results by the image_price_1k field.
|
||||
func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImagePrice2k orders the results by the image_price_2k field.
|
||||
func ByImagePrice2k(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImagePrice2k, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImagePrice4k orders the results by the image_price_4k field.
|
||||
func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByClaudeCodeOnly orders the results by the claude_code_only field.
|
||||
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByFallbackGroupID orders the results by the fallback_group_id field.
|
||||
func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -125,6 +125,31 @@ func DefaultValidityDays(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v))
|
||||
}
|
||||
|
||||
// ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ.
|
||||
func ImagePrice1k(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
|
||||
}
|
||||
|
||||
// ImagePrice2k applies equality check predicate on the "image_price_2k" field. It's identical to ImagePrice2kEQ.
|
||||
func ImagePrice2k(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice2k, v))
|
||||
}
|
||||
|
||||
// ImagePrice4k applies equality check predicate on the "image_price_4k" field. It's identical to ImagePrice4kEQ.
|
||||
func ImagePrice4k(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
|
||||
func ClaudeCodeOnly(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||
}
|
||||
|
||||
// FallbackGroupID applies equality check predicate on the "fallback_group_id" field. It's identical to FallbackGroupIDEQ.
|
||||
func FallbackGroupID(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -830,6 +855,216 @@ func DefaultValidityDaysLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kNEQ applies the NEQ predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kNEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldImagePrice1k, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kIn applies the In predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldImagePrice1k, vs...))
|
||||
}
|
||||
|
||||
// ImagePrice1kNotIn applies the NotIn predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kNotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldImagePrice1k, vs...))
|
||||
}
|
||||
|
||||
// ImagePrice1kGT applies the GT predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kGT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldImagePrice1k, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kGTE applies the GTE predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kGTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldImagePrice1k, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kLT applies the LT predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kLT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldImagePrice1k, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kLTE applies the LTE predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kLTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldImagePrice1k, v))
|
||||
}
|
||||
|
||||
// ImagePrice1kIsNil applies the IsNil predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kIsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldImagePrice1k))
|
||||
}
|
||||
|
||||
// ImagePrice1kNotNil applies the NotNil predicate on the "image_price_1k" field.
|
||||
func ImagePrice1kNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldImagePrice1k))
|
||||
}
|
||||
|
||||
// ImagePrice2kEQ applies the EQ predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice2k, v))
|
||||
}
|
||||
|
||||
// ImagePrice2kNEQ applies the NEQ predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kNEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldImagePrice2k, v))
|
||||
}
|
||||
|
||||
// ImagePrice2kIn applies the In predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldImagePrice2k, vs...))
|
||||
}
|
||||
|
||||
// ImagePrice2kNotIn applies the NotIn predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kNotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldImagePrice2k, vs...))
|
||||
}
|
||||
|
||||
// ImagePrice2kGT applies the GT predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kGT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldImagePrice2k, v))
|
||||
}
|
||||
|
||||
// ImagePrice2kGTE applies the GTE predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kGTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldImagePrice2k, v))
|
||||
}
|
||||
|
||||
// ImagePrice2kLT applies the LT predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kLT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldImagePrice2k, v))
|
||||
}
|
||||
|
||||
// ImagePrice2kLTE applies the LTE predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kLTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldImagePrice2k, v))
|
||||
}
|
||||
|
||||
// ImagePrice2kIsNil applies the IsNil predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kIsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldImagePrice2k))
|
||||
}
|
||||
|
||||
// ImagePrice2kNotNil applies the NotNil predicate on the "image_price_2k" field.
|
||||
func ImagePrice2kNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldImagePrice2k))
|
||||
}
|
||||
|
||||
// ImagePrice4kEQ applies the EQ predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// ImagePrice4kNEQ applies the NEQ predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kNEQ(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// ImagePrice4kIn applies the In predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldImagePrice4k, vs...))
|
||||
}
|
||||
|
||||
// ImagePrice4kNotIn applies the NotIn predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kNotIn(vs ...float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldImagePrice4k, vs...))
|
||||
}
|
||||
|
||||
// ImagePrice4kGT applies the GT predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kGT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// ImagePrice4kGTE applies the GTE predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kGTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// ImagePrice4kLT applies the LT predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kLT(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// ImagePrice4kLTE applies the LTE predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kLTE(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// ImagePrice4kIsNil applies the IsNil predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kIsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldImagePrice4k))
|
||||
}
|
||||
|
||||
// ImagePrice4kNotNil applies the NotNil predicate on the "image_price_4k" field.
|
||||
func ImagePrice4kNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
|
||||
}
|
||||
|
||||
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
|
||||
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||
}
|
||||
|
||||
// ClaudeCodeOnlyNEQ applies the NEQ predicate on the "claude_code_only" field.
|
||||
func ClaudeCodeOnlyNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldClaudeCodeOnly, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDEQ applies the EQ predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDEQ(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDNEQ applies the NEQ predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDNEQ(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDIn applies the In predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDIn(vs ...int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldFallbackGroupID, vs...))
|
||||
}
|
||||
|
||||
// FallbackGroupIDNotIn applies the NotIn predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDNotIn(vs ...int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupID, vs...))
|
||||
}
|
||||
|
||||
// FallbackGroupIDGT applies the GT predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDGT(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDGTE applies the GTE predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDGTE(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDLT applies the LT predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDLT(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDLTE applies the LTE predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDLTE(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDIsNil applies the IsNil predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDIsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupID))
|
||||
}
|
||||
|
||||
// FallbackGroupIDNotNil applies the NotNil predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -216,6 +216,76 @@ func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate {
|
||||
_c.mutation.SetImagePrice1k(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImagePrice1k sets the "image_price_1k" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableImagePrice1k(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetImagePrice1k(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImagePrice2k sets the "image_price_2k" field.
|
||||
func (_c *GroupCreate) SetImagePrice2k(v float64) *GroupCreate {
|
||||
_c.mutation.SetImagePrice2k(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImagePrice2k sets the "image_price_2k" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableImagePrice2k(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetImagePrice2k(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImagePrice4k sets the "image_price_4k" field.
|
||||
func (_c *GroupCreate) SetImagePrice4k(v float64) *GroupCreate {
|
||||
_c.mutation.SetImagePrice4k(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImagePrice4k sets the "image_price_4k" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetImagePrice4k(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableClaudeCodeOnly(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetClaudeCodeOnly(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (_c *GroupCreate) SetFallbackGroupID(v int64) *GroupCreate {
|
||||
_c.mutation.SetFallbackGroupID(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetFallbackGroupID(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -381,6 +451,10 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultDefaultValidityDays
|
||||
_c.mutation.SetDefaultValidityDays(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
v := group.DefaultClaudeCodeOnly
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -433,6 +507,9 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
||||
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -516,6 +593,26 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
_node.DefaultValidityDays = value
|
||||
}
|
||||
if value, ok := _c.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
_node.ImagePrice1k = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImagePrice2k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice2k, field.TypeFloat64, value)
|
||||
_node.ImagePrice2k = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImagePrice4k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||
_node.ImagePrice4k = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
_node.ClaudeCodeOnly = value
|
||||
}
|
||||
if value, ok := _c.mutation.FallbackGroupID(); ok {
|
||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
_node.FallbackGroupID = &value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -888,6 +985,114 @@ func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldImagePrice1k, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImagePrice1k sets the "image_price_1k" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateImagePrice1k() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldImagePrice1k)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddImagePrice1k adds v to the "image_price_1k" field.
|
||||
func (u *GroupUpsert) AddImagePrice1k(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldImagePrice1k, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImagePrice1k clears the value of the "image_price_1k" field.
|
||||
func (u *GroupUpsert) ClearImagePrice1k() *GroupUpsert {
|
||||
u.SetNull(group.FieldImagePrice1k)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImagePrice2k sets the "image_price_2k" field.
|
||||
func (u *GroupUpsert) SetImagePrice2k(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldImagePrice2k, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImagePrice2k sets the "image_price_2k" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateImagePrice2k() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldImagePrice2k)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddImagePrice2k adds v to the "image_price_2k" field.
|
||||
func (u *GroupUpsert) AddImagePrice2k(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldImagePrice2k, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImagePrice2k clears the value of the "image_price_2k" field.
|
||||
func (u *GroupUpsert) ClearImagePrice2k() *GroupUpsert {
|
||||
u.SetNull(group.FieldImagePrice2k)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImagePrice4k sets the "image_price_4k" field.
|
||||
func (u *GroupUpsert) SetImagePrice4k(v float64) *GroupUpsert {
|
||||
u.Set(group.FieldImagePrice4k, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImagePrice4k sets the "image_price_4k" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateImagePrice4k() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldImagePrice4k)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddImagePrice4k adds v to the "image_price_4k" field.
|
||||
func (u *GroupUpsert) AddImagePrice4k(v float64) *GroupUpsert {
|
||||
u.Add(group.FieldImagePrice4k, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImagePrice4k clears the value of the "image_price_4k" field.
|
||||
func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
|
||||
u.SetNull(group.FieldImagePrice4k)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldClaudeCodeOnly, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateClaudeCodeOnly() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldClaudeCodeOnly)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (u *GroupUpsert) SetFallbackGroupID(v int64) *GroupUpsert {
|
||||
u.Set(group.FieldFallbackGroupID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateFallbackGroupID() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldFallbackGroupID)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||
func (u *GroupUpsert) AddFallbackGroupID(v int64) *GroupUpsert {
|
||||
u.Add(group.FieldFallbackGroupID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
|
||||
u.SetNull(group.FieldFallbackGroupID)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1185,6 +1390,132 @@ func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImagePrice1k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImagePrice1k adds v to the "image_price_1k" field.
|
||||
func (u *GroupUpsertOne) AddImagePrice1k(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImagePrice1k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImagePrice1k sets the "image_price_1k" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateImagePrice1k() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImagePrice1k()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImagePrice1k clears the value of the "image_price_1k" field.
|
||||
func (u *GroupUpsertOne) ClearImagePrice1k() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearImagePrice1k()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice2k sets the "image_price_2k" field.
|
||||
func (u *GroupUpsertOne) SetImagePrice2k(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImagePrice2k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImagePrice2k adds v to the "image_price_2k" field.
|
||||
func (u *GroupUpsertOne) AddImagePrice2k(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImagePrice2k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImagePrice2k sets the "image_price_2k" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateImagePrice2k() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImagePrice2k()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImagePrice2k clears the value of the "image_price_2k" field.
|
||||
func (u *GroupUpsertOne) ClearImagePrice2k() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearImagePrice2k()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice4k sets the "image_price_4k" field.
|
||||
func (u *GroupUpsertOne) SetImagePrice4k(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImagePrice4k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImagePrice4k adds v to the "image_price_4k" field.
|
||||
func (u *GroupUpsertOne) AddImagePrice4k(v float64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImagePrice4k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImagePrice4k sets the "image_price_4k" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateImagePrice4k() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImagePrice4k()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImagePrice4k clears the value of the "image_price_4k" field.
|
||||
func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearImagePrice4k()
|
||||
})
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetClaudeCodeOnly(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateClaudeCodeOnly() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateClaudeCodeOnly()
|
||||
})
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (u *GroupUpsertOne) SetFallbackGroupID(v int64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetFallbackGroupID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||
func (u *GroupUpsertOne) AddFallbackGroupID(v int64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddFallbackGroupID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateFallbackGroupID() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateFallbackGroupID()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearFallbackGroupID()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -1648,6 +1979,132 @@ func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImagePrice1k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImagePrice1k adds v to the "image_price_1k" field.
|
||||
func (u *GroupUpsertBulk) AddImagePrice1k(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImagePrice1k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImagePrice1k sets the "image_price_1k" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateImagePrice1k() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImagePrice1k()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImagePrice1k clears the value of the "image_price_1k" field.
|
||||
func (u *GroupUpsertBulk) ClearImagePrice1k() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearImagePrice1k()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice2k sets the "image_price_2k" field.
|
||||
func (u *GroupUpsertBulk) SetImagePrice2k(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImagePrice2k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImagePrice2k adds v to the "image_price_2k" field.
|
||||
func (u *GroupUpsertBulk) AddImagePrice2k(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImagePrice2k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImagePrice2k sets the "image_price_2k" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateImagePrice2k() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImagePrice2k()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImagePrice2k clears the value of the "image_price_2k" field.
|
||||
func (u *GroupUpsertBulk) ClearImagePrice2k() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearImagePrice2k()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImagePrice4k sets the "image_price_4k" field.
|
||||
func (u *GroupUpsertBulk) SetImagePrice4k(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetImagePrice4k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImagePrice4k adds v to the "image_price_4k" field.
|
||||
func (u *GroupUpsertBulk) AddImagePrice4k(v float64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddImagePrice4k(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImagePrice4k sets the "image_price_4k" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateImagePrice4k() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateImagePrice4k()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImagePrice4k clears the value of the "image_price_4k" field.
|
||||
func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearImagePrice4k()
|
||||
})
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetClaudeCodeOnly(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateClaudeCodeOnly() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateClaudeCodeOnly()
|
||||
})
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (u *GroupUpsertBulk) SetFallbackGroupID(v int64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetFallbackGroupID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||
func (u *GroupUpsertBulk) AddFallbackGroupID(v int64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddFallbackGroupID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateFallbackGroupID() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateFallbackGroupID()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearFallbackGroupID()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -273,6 +273,128 @@ func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetImagePrice1k()
|
||||
_u.mutation.SetImagePrice1k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImagePrice1k sets the "image_price_1k" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableImagePrice1k(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetImagePrice1k(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImagePrice1k adds value to the "image_price_1k" field.
|
||||
func (_u *GroupUpdate) AddImagePrice1k(v float64) *GroupUpdate {
|
||||
_u.mutation.AddImagePrice1k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImagePrice1k clears the value of the "image_price_1k" field.
|
||||
func (_u *GroupUpdate) ClearImagePrice1k() *GroupUpdate {
|
||||
_u.mutation.ClearImagePrice1k()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice2k sets the "image_price_2k" field.
|
||||
func (_u *GroupUpdate) SetImagePrice2k(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetImagePrice2k()
|
||||
_u.mutation.SetImagePrice2k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImagePrice2k sets the "image_price_2k" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableImagePrice2k(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetImagePrice2k(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImagePrice2k adds value to the "image_price_2k" field.
|
||||
func (_u *GroupUpdate) AddImagePrice2k(v float64) *GroupUpdate {
|
||||
_u.mutation.AddImagePrice2k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImagePrice2k clears the value of the "image_price_2k" field.
|
||||
func (_u *GroupUpdate) ClearImagePrice2k() *GroupUpdate {
|
||||
_u.mutation.ClearImagePrice2k()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice4k sets the "image_price_4k" field.
|
||||
func (_u *GroupUpdate) SetImagePrice4k(v float64) *GroupUpdate {
|
||||
_u.mutation.ResetImagePrice4k()
|
||||
_u.mutation.SetImagePrice4k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImagePrice4k sets the "image_price_4k" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableImagePrice4k(v *float64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetImagePrice4k(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImagePrice4k adds value to the "image_price_4k" field.
|
||||
func (_u *GroupUpdate) AddImagePrice4k(v float64) *GroupUpdate {
|
||||
_u.mutation.AddImagePrice4k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImagePrice4k clears the value of the "image_price_4k" field.
|
||||
func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
|
||||
_u.mutation.ClearImagePrice4k()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
|
||||
_u.mutation.SetClaudeCodeOnly(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableClaudeCodeOnly(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetClaudeCodeOnly(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (_u *GroupUpdate) SetFallbackGroupID(v int64) *GroupUpdate {
|
||||
_u.mutation.ResetFallbackGroupID()
|
||||
_u.mutation.SetFallbackGroupID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableFallbackGroupID(v *int64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetFallbackGroupID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds value to the "fallback_group_id" field.
|
||||
func (_u *GroupUpdate) AddFallbackGroupID(v int64) *GroupUpdate {
|
||||
_u.mutation.AddFallbackGroupID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
|
||||
_u.mutation.ClearFallbackGroupID()
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -642,6 +764,45 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
|
||||
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImagePrice1k(); ok {
|
||||
_spec.AddField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.ImagePrice1kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice1k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice2k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice2k, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImagePrice2k(); ok {
|
||||
_spec.AddField(group.FieldImagePrice2k, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.ImagePrice2kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice2k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice4k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImagePrice4k(); ok {
|
||||
_spec.AddField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.ImagePrice4kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.FallbackGroupID(); ok {
|
||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
|
||||
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.FallbackGroupIDCleared() {
|
||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1195,6 +1356,128 @@ func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice1k sets the "image_price_1k" field.
|
||||
func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetImagePrice1k()
|
||||
_u.mutation.SetImagePrice1k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImagePrice1k sets the "image_price_1k" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableImagePrice1k(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImagePrice1k(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImagePrice1k adds value to the "image_price_1k" field.
|
||||
func (_u *GroupUpdateOne) AddImagePrice1k(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddImagePrice1k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImagePrice1k clears the value of the "image_price_1k" field.
|
||||
func (_u *GroupUpdateOne) ClearImagePrice1k() *GroupUpdateOne {
|
||||
_u.mutation.ClearImagePrice1k()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice2k sets the "image_price_2k" field.
|
||||
func (_u *GroupUpdateOne) SetImagePrice2k(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetImagePrice2k()
|
||||
_u.mutation.SetImagePrice2k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImagePrice2k sets the "image_price_2k" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableImagePrice2k(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImagePrice2k(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImagePrice2k adds value to the "image_price_2k" field.
|
||||
func (_u *GroupUpdateOne) AddImagePrice2k(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddImagePrice2k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImagePrice2k clears the value of the "image_price_2k" field.
|
||||
func (_u *GroupUpdateOne) ClearImagePrice2k() *GroupUpdateOne {
|
||||
_u.mutation.ClearImagePrice2k()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImagePrice4k sets the "image_price_4k" field.
|
||||
func (_u *GroupUpdateOne) SetImagePrice4k(v float64) *GroupUpdateOne {
|
||||
_u.mutation.ResetImagePrice4k()
|
||||
_u.mutation.SetImagePrice4k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImagePrice4k sets the "image_price_4k" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableImagePrice4k(v *float64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImagePrice4k(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImagePrice4k adds value to the "image_price_4k" field.
|
||||
func (_u *GroupUpdateOne) AddImagePrice4k(v float64) *GroupUpdateOne {
|
||||
_u.mutation.AddImagePrice4k(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImagePrice4k clears the value of the "image_price_4k" field.
|
||||
func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
|
||||
_u.mutation.ClearImagePrice4k()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetClaudeCodeOnly(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableClaudeCodeOnly(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetClaudeCodeOnly(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (_u *GroupUpdateOne) SetFallbackGroupID(v int64) *GroupUpdateOne {
|
||||
_u.mutation.ResetFallbackGroupID()
|
||||
_u.mutation.SetFallbackGroupID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableFallbackGroupID(v *int64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetFallbackGroupID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds value to the "fallback_group_id" field.
|
||||
func (_u *GroupUpdateOne) AddFallbackGroupID(v int64) *GroupUpdateOne {
|
||||
_u.mutation.AddFallbackGroupID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
|
||||
_u.mutation.ClearFallbackGroupID()
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1594,6 +1877,45 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
|
||||
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice1k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImagePrice1k(); ok {
|
||||
_spec.AddField(group.FieldImagePrice1k, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.ImagePrice1kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice1k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice2k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice2k, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImagePrice2k(); ok {
|
||||
_spec.AddField(group.FieldImagePrice2k, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.ImagePrice2kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice2k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.ImagePrice4k(); ok {
|
||||
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImagePrice4k(); ok {
|
||||
_spec.AddField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.ImagePrice4kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.FallbackGroupID(); ok {
|
||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
|
||||
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.FallbackGroupIDCleared() {
|
||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -80,6 +80,8 @@ var (
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "auto_pause_on_expired", Type: field.TypeBool, Default: true},
|
||||
{Name: "schedulable", Type: field.TypeBool, Default: true},
|
||||
{Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
@@ -97,7 +99,7 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "accounts_proxies_proxy",
|
||||
Columns: []*schema.Column{AccountsColumns[22]},
|
||||
Columns: []*schema.Column{AccountsColumns[24]},
|
||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -121,7 +123,7 @@ var (
|
||||
{
|
||||
Name: "account_proxy_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[22]},
|
||||
Columns: []*schema.Column{AccountsColumns[24]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority",
|
||||
@@ -136,22 +138,22 @@ var (
|
||||
{
|
||||
Name: "account_schedulable",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[15]},
|
||||
Columns: []*schema.Column{AccountsColumns[17]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limited_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[16]},
|
||||
Columns: []*schema.Column{AccountsColumns[18]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limit_reset_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[17]},
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
},
|
||||
{
|
||||
Name: "account_overload_until",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[18]},
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
},
|
||||
{
|
||||
Name: "account_deleted_at",
|
||||
@@ -216,6 +218,11 @@ var (
|
||||
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "default_validity_days", Type: field.TypeInt, Default: 30},
|
||||
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
@@ -368,6 +375,9 @@ var (
|
||||
{Name: "stream", Type: field.TypeBool, Default: false},
|
||||
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "first_token_ms", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512},
|
||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "api_key_id", Type: field.TypeInt64},
|
||||
{Name: "account_id", Type: field.TypeInt64},
|
||||
@@ -383,31 +393,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[21]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[22]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[23]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -416,32 +426,32 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[21]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[22]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[23]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[20]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[23]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
@@ -456,12 +466,12 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[20]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[23]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[21], UsageLogsColumns[20]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[23]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -181,12 +181,16 @@ func init() {
|
||||
account.DefaultStatus = accountDescStatus.Default.(string)
|
||||
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
|
||||
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
|
||||
accountDescAutoPauseOnExpired := accountFields[13].Descriptor()
|
||||
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
|
||||
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
|
||||
// accountDescSchedulable is the schema descriptor for schedulable field.
|
||||
accountDescSchedulable := accountFields[12].Descriptor()
|
||||
accountDescSchedulable := accountFields[14].Descriptor()
|
||||
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
||||
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
||||
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
||||
accountDescSessionWindowStatus := accountFields[18].Descriptor()
|
||||
accountDescSessionWindowStatus := accountFields[20].Descriptor()
|
||||
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
||||
accountgroupFields := schema.AccountGroup{}.Fields()
|
||||
@@ -266,6 +270,10 @@ func init() {
|
||||
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
||||
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
||||
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
||||
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
||||
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||
proxyMixin := schema.Proxy{}.Mixin()
|
||||
proxyMixinHooks1 := proxyMixin[1].Hooks()
|
||||
proxy.Hooks[0] = proxyMixinHooks1[0]
|
||||
@@ -521,8 +529,20 @@ func init() {
|
||||
usagelogDescStream := usagelogFields[21].Descriptor()
|
||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||
usagelogDescUserAgent := usagelogFields[24].Descriptor()
|
||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||
usagelogDescImageCount := usagelogFields[25].Descriptor()
|
||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||
usagelogDescImageSize := usagelogFields[26].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[24].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[27].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
|
||||
@@ -118,6 +118,16 @@ func (Account) Fields() []ent.Field {
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
// expires_at: 账户过期时间(可为空)
|
||||
field.Time("expires_at").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("Account expiration time (NULL means no expiration).").
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
// auto_pause_on_expired: 过期后自动暂停调度
|
||||
field.Bool("auto_pause_on_expired").
|
||||
Default(true).
|
||||
Comment("Auto pause scheduling when account expires."),
|
||||
|
||||
// ========== 调度和速率限制相关字段 ==========
|
||||
// 这些字段在 migrations/005_schema_parity.sql 中添加
|
||||
|
||||
@@ -72,6 +72,29 @@ func (Group) Fields() []ent.Field {
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
field.Int("default_validity_days").
|
||||
Default(30),
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
field.Float("image_price_1k").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
field.Float("image_price_2k").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
field.Float("image_price_4k").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
|
||||
// Claude Code 客户端限制 (added by migration 029)
|
||||
field.Bool("claude_code_only").
|
||||
Default(false).
|
||||
Comment("是否仅允许 Claude Code 客户端"),
|
||||
field.Int64("fallback_group_id").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,6 +110,8 @@ func (Group) Edges() []ent.Edge {
|
||||
edge.From("allowed_users", User.Type).
|
||||
Ref("allowed_groups").
|
||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
||||
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -96,6 +96,18 @@ func (UsageLog) Fields() []ent.Field {
|
||||
field.Int("first_token_ms").
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.String("user_agent").
|
||||
MaxLen(512).
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// 图片生成字段(仅 gemini-3-pro-image 等图片模型使用)
|
||||
field.Int("image_count").
|
||||
Default(0),
|
||||
field.String("image_size").
|
||||
MaxLen(10).
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// 时间戳(只有 created_at,日志不可修改)
|
||||
field.Time("created_at").
|
||||
|
||||
@@ -70,6 +70,12 @@ type UsageLog struct {
|
||||
DurationMs *int `json:"duration_ms,omitempty"`
|
||||
// FirstTokenMs holds the value of the "first_token_ms" field.
|
||||
FirstTokenMs *int `json:"first_token_ms,omitempty"`
|
||||
// UserAgent holds the value of the "user_agent" field.
|
||||
UserAgent *string `json:"user_agent,omitempty"`
|
||||
// ImageCount holds the value of the "image_count" field.
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
// ImageSize holds the value of the "image_size" field.
|
||||
ImageSize *string `json:"image_size,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
@@ -159,9 +165,9 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs:
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel:
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagelog.FieldCreatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -334,6 +340,26 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
_m.FirstTokenMs = new(int)
|
||||
*_m.FirstTokenMs = int(value.Int64)
|
||||
}
|
||||
case usagelog.FieldUserAgent:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field user_agent", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UserAgent = new(string)
|
||||
*_m.UserAgent = value.String
|
||||
}
|
||||
case usagelog.FieldImageCount:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_count", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageCount = int(value.Int64)
|
||||
}
|
||||
case usagelog.FieldImageSize:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_size", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageSize = new(string)
|
||||
*_m.ImageSize = value.String
|
||||
}
|
||||
case usagelog.FieldCreatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||
@@ -481,6 +507,19 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.UserAgent; v != nil {
|
||||
builder.WriteString("user_agent=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("image_count=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageCount))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImageSize; v != nil {
|
||||
builder.WriteString("image_size=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteByte(')')
|
||||
|
||||
@@ -62,6 +62,12 @@ const (
|
||||
FieldDurationMs = "duration_ms"
|
||||
// FieldFirstTokenMs holds the string denoting the first_token_ms field in the database.
|
||||
FieldFirstTokenMs = "first_token_ms"
|
||||
// FieldUserAgent holds the string denoting the user_agent field in the database.
|
||||
FieldUserAgent = "user_agent"
|
||||
// FieldImageCount holds the string denoting the image_count field in the database.
|
||||
FieldImageCount = "image_count"
|
||||
// FieldImageSize holds the string denoting the image_size field in the database.
|
||||
FieldImageSize = "image_size"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||
@@ -140,6 +146,9 @@ var Columns = []string{
|
||||
FieldStream,
|
||||
FieldDurationMs,
|
||||
FieldFirstTokenMs,
|
||||
FieldUserAgent,
|
||||
FieldImageCount,
|
||||
FieldImageSize,
|
||||
FieldCreatedAt,
|
||||
}
|
||||
|
||||
@@ -188,6 +197,12 @@ var (
|
||||
DefaultBillingType int8
|
||||
// DefaultStream holds the default value on creation for the "stream" field.
|
||||
DefaultStream bool
|
||||
// UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
UserAgentValidator func(string) error
|
||||
// DefaultImageCount holds the default value on creation for the "image_count" field.
|
||||
DefaultImageCount int
|
||||
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
ImageSizeValidator func(string) error
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
)
|
||||
@@ -320,6 +335,21 @@ func ByFirstTokenMs(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldFirstTokenMs, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUserAgent orders the results by the user_agent field.
|
||||
func ByUserAgent(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageCount orders the results by the image_count field.
|
||||
func ByImageCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageCount, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageSize orders the results by the image_size field.
|
||||
func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCreatedAt orders the results by the created_at field.
|
||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
|
||||
@@ -175,6 +175,21 @@ func FirstTokenMs(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldFirstTokenMs, v))
|
||||
}
|
||||
|
||||
// UserAgent applies equality check predicate on the "user_agent" field. It's identical to UserAgentEQ.
|
||||
func UserAgent(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ.
|
||||
func ImageCount(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
|
||||
}
|
||||
|
||||
// ImageSize applies equality check predicate on the "image_size" field. It's identical to ImageSizeEQ.
|
||||
func ImageSize(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||
func CreatedAt(v time.Time) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1100,6 +1115,196 @@ func FirstTokenMsNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldFirstTokenMs))
|
||||
}
|
||||
|
||||
// UserAgentEQ applies the EQ predicate on the "user_agent" field.
|
||||
func UserAgentEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentNEQ applies the NEQ predicate on the "user_agent" field.
|
||||
func UserAgentNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentIn applies the In predicate on the "user_agent" field.
|
||||
func UserAgentIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldUserAgent, vs...))
|
||||
}
|
||||
|
||||
// UserAgentNotIn applies the NotIn predicate on the "user_agent" field.
|
||||
func UserAgentNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldUserAgent, vs...))
|
||||
}
|
||||
|
||||
// UserAgentGT applies the GT predicate on the "user_agent" field.
|
||||
func UserAgentGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentGTE applies the GTE predicate on the "user_agent" field.
|
||||
func UserAgentGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentLT applies the LT predicate on the "user_agent" field.
|
||||
func UserAgentLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentLTE applies the LTE predicate on the "user_agent" field.
|
||||
func UserAgentLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentContains applies the Contains predicate on the "user_agent" field.
|
||||
func UserAgentContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentHasPrefix applies the HasPrefix predicate on the "user_agent" field.
|
||||
func UserAgentHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentHasSuffix applies the HasSuffix predicate on the "user_agent" field.
|
||||
func UserAgentHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentIsNil applies the IsNil predicate on the "user_agent" field.
|
||||
func UserAgentIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldUserAgent))
|
||||
}
|
||||
|
||||
// UserAgentNotNil applies the NotNil predicate on the "user_agent" field.
|
||||
func UserAgentNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldUserAgent))
|
||||
}
|
||||
|
||||
// UserAgentEqualFold applies the EqualFold predicate on the "user_agent" field.
|
||||
func UserAgentEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// UserAgentContainsFold applies the ContainsFold predicate on the "user_agent" field.
|
||||
func UserAgentContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// ImageCountEQ applies the EQ predicate on the "image_count" field.
|
||||
func ImageCountEQ(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
|
||||
}
|
||||
|
||||
// ImageCountNEQ applies the NEQ predicate on the "image_count" field.
|
||||
func ImageCountNEQ(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldImageCount, v))
|
||||
}
|
||||
|
||||
// ImageCountIn applies the In predicate on the "image_count" field.
|
||||
func ImageCountIn(vs ...int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldImageCount, vs...))
|
||||
}
|
||||
|
||||
// ImageCountNotIn applies the NotIn predicate on the "image_count" field.
|
||||
func ImageCountNotIn(vs ...int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldImageCount, vs...))
|
||||
}
|
||||
|
||||
// ImageCountGT applies the GT predicate on the "image_count" field.
|
||||
func ImageCountGT(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldImageCount, v))
|
||||
}
|
||||
|
||||
// ImageCountGTE applies the GTE predicate on the "image_count" field.
|
||||
func ImageCountGTE(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldImageCount, v))
|
||||
}
|
||||
|
||||
// ImageCountLT applies the LT predicate on the "image_count" field.
|
||||
func ImageCountLT(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldImageCount, v))
|
||||
}
|
||||
|
||||
// ImageCountLTE applies the LTE predicate on the "image_count" field.
|
||||
func ImageCountLTE(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldImageCount, v))
|
||||
}
|
||||
|
||||
// ImageSizeEQ applies the EQ predicate on the "image_size" field.
|
||||
func ImageSizeEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeNEQ applies the NEQ predicate on the "image_size" field.
|
||||
func ImageSizeNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeIn applies the In predicate on the "image_size" field.
|
||||
func ImageSizeIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldImageSize, vs...))
|
||||
}
|
||||
|
||||
// ImageSizeNotIn applies the NotIn predicate on the "image_size" field.
|
||||
func ImageSizeNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldImageSize, vs...))
|
||||
}
|
||||
|
||||
// ImageSizeGT applies the GT predicate on the "image_size" field.
|
||||
func ImageSizeGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeGTE applies the GTE predicate on the "image_size" field.
|
||||
func ImageSizeGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeLT applies the LT predicate on the "image_size" field.
|
||||
func ImageSizeLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeLTE applies the LTE predicate on the "image_size" field.
|
||||
func ImageSizeLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeContains applies the Contains predicate on the "image_size" field.
|
||||
func ImageSizeContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeHasPrefix applies the HasPrefix predicate on the "image_size" field.
|
||||
func ImageSizeHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeHasSuffix applies the HasSuffix predicate on the "image_size" field.
|
||||
func ImageSizeHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeIsNil applies the IsNil predicate on the "image_size" field.
|
||||
func ImageSizeIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageSize))
|
||||
}
|
||||
|
||||
// ImageSizeNotNil applies the NotNil predicate on the "image_size" field.
|
||||
func ImageSizeNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageSize))
|
||||
}
|
||||
|
||||
// ImageSizeEqualFold applies the EqualFold predicate on the "image_size" field.
|
||||
func ImageSizeEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeContainsFold applies the ContainsFold predicate on the "image_size" field.
|
||||
func ImageSizeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
|
||||
|
||||
@@ -323,6 +323,48 @@ func (_c *UsageLogCreate) SetNillableFirstTokenMs(v *int) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (_c *UsageLogCreate) SetUserAgent(v string) *UsageLogCreate {
|
||||
_c.mutation.SetUserAgent(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetUserAgent(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate {
|
||||
_c.mutation.SetImageCount(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageCount sets the "image_count" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableImageCount(v *int) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetImageCount(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageSize sets the "image_size" field.
|
||||
func (_c *UsageLogCreate) SetImageSize(v string) *UsageLogCreate {
|
||||
_c.mutation.SetImageSize(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageSize sets the "image_size" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetImageSize(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetCreatedAt sets the "created_at" field.
|
||||
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
|
||||
_c.mutation.SetCreatedAt(v)
|
||||
@@ -457,6 +499,10 @@ func (_c *UsageLogCreate) defaults() {
|
||||
v := usagelog.DefaultStream
|
||||
_c.mutation.SetStream(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ImageCount(); !ok {
|
||||
v := usagelog.DefaultImageCount
|
||||
_c.mutation.SetImageCount(v)
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
v := usagelog.DefaultCreatedAt()
|
||||
_c.mutation.SetCreatedAt(v)
|
||||
@@ -535,6 +581,19 @@ func (_c *UsageLogCreate) check() error {
|
||||
if _, ok := _c.mutation.Stream(); !ok {
|
||||
return &ValidationError{Name: "stream", err: errors.New(`ent: missing required field "UsageLog.stream"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.ImageCount(); !ok {
|
||||
return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.ImageSize(); ok {
|
||||
if err := usagelog.ImageSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
|
||||
}
|
||||
@@ -650,6 +709,18 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldFirstTokenMs, field.TypeInt, value)
|
||||
_node.FirstTokenMs = &value
|
||||
}
|
||||
if value, ok := _c.mutation.UserAgent(); ok {
|
||||
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
|
||||
_node.UserAgent = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageCount(); ok {
|
||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||
_node.ImageCount = value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
|
||||
_node.ImageSize = &value
|
||||
}
|
||||
if value, ok := _c.mutation.CreatedAt(); ok {
|
||||
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
|
||||
_node.CreatedAt = value
|
||||
@@ -1199,6 +1270,60 @@ func (u *UsageLogUpsert) ClearFirstTokenMs() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (u *UsageLogUpsert) SetUserAgent(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldUserAgent, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateUserAgent sets the "user_agent" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateUserAgent() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldUserAgent)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearUserAgent clears the value of the "user_agent" field.
|
||||
func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldUserAgent)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageCount, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageCount sets the "image_count" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageCount() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageCount)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddImageCount adds v to the "image_count" field.
|
||||
func (u *UsageLogUpsert) AddImageCount(v int) *UsageLogUpsert {
|
||||
u.Add(usagelog.FieldImageCount, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageSize sets the "image_size" field.
|
||||
func (u *UsageLogUpsert) SetImageSize(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageSize, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageSize sets the "image_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageSize() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageSize clears the value of the "image_size" field.
|
||||
func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1720,6 +1845,69 @@ func (u *UsageLogUpsertOne) ClearFirstTokenMs() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (u *UsageLogUpsertOne) SetUserAgent(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetUserAgent(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAgent sets the "user_agent" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateUserAgent() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateUserAgent()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearUserAgent clears the value of the "user_agent" field.
|
||||
func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearUserAgent()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageCount(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImageCount adds v to the "image_count" field.
|
||||
func (u *UsageLogUpsertOne) AddImageCount(v int) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.AddImageCount(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageCount sets the "image_count" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageCount() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageCount()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSize sets the "image_size" field.
|
||||
func (u *UsageLogUpsertOne) SetImageSize(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSize sets the "image_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSize clears the value of the "image_size" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSize()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -2407,6 +2595,69 @@ func (u *UsageLogUpsertBulk) ClearFirstTokenMs() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (u *UsageLogUpsertBulk) SetUserAgent(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetUserAgent(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAgent sets the "user_agent" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateUserAgent() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateUserAgent()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearUserAgent clears the value of the "user_agent" field.
|
||||
func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearUserAgent()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageCount(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddImageCount adds v to the "image_count" field.
|
||||
func (u *UsageLogUpsertBulk) AddImageCount(v int) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.AddImageCount(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageCount sets the "image_count" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageCount() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageCount()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSize sets the "image_size" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageSize(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSize sets the "image_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSize clears the value of the "image_size" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSize()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -504,6 +504,67 @@ func (_u *UsageLogUpdate) ClearFirstTokenMs() *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (_u *UsageLogUpdate) SetUserAgent(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetUserAgent(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableUserAgent(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetUserAgent(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearUserAgent clears the value of the "user_agent" field.
|
||||
func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate {
|
||||
_u.mutation.ClearUserAgent()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate {
|
||||
_u.mutation.ResetImageCount()
|
||||
_u.mutation.SetImageCount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageCount sets the "image_count" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableImageCount(v *int) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageCount(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImageCount adds value to the "image_count" field.
|
||||
func (_u *UsageLogUpdate) AddImageCount(v int) *UsageLogUpdate {
|
||||
_u.mutation.AddImageCount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSize sets the "image_size" field.
|
||||
func (_u *UsageLogUpdate) SetImageSize(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetImageSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageSize sets the "image_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableImageSize(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSize clears the value of the "image_size" field.
|
||||
func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
|
||||
return _u.SetUserID(v.ID)
|
||||
@@ -603,6 +664,16 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageSize(); ok {
|
||||
if err := usagelog.ImageSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||
}
|
||||
@@ -738,6 +809,24 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.FirstTokenMsCleared() {
|
||||
_spec.ClearField(usagelog.FieldFirstTokenMs, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.UserAgent(); ok {
|
||||
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.UserAgentCleared() {
|
||||
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageCount(); ok {
|
||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImageCount(); ok {
|
||||
_spec.AddField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
@@ -1375,6 +1464,67 @@ func (_u *UsageLogUpdateOne) ClearFirstTokenMs() *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (_u *UsageLogUpdateOne) SetUserAgent(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetUserAgent(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableUserAgent(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUserAgent(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearUserAgent clears the value of the "user_agent" field.
|
||||
func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearUserAgent()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne {
|
||||
_u.mutation.ResetImageCount()
|
||||
_u.mutation.SetImageCount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageCount sets the "image_count" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableImageCount(v *int) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageCount(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddImageCount adds value to the "image_count" field.
|
||||
func (_u *UsageLogUpdateOne) AddImageCount(v int) *UsageLogUpdateOne {
|
||||
_u.mutation.AddImageCount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSize sets the "image_size" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageSize(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageSize sets the "image_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableImageSize(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSize clears the value of the "image_size" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
|
||||
return _u.SetUserID(v.ID)
|
||||
@@ -1487,6 +1637,16 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UserAgent(); ok {
|
||||
if err := usagelog.UserAgentValidator(v); err != nil {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageSize(); ok {
|
||||
if err := usagelog.ImageSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||
}
|
||||
@@ -1639,6 +1799,24 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if _u.mutation.FirstTokenMsCleared() {
|
||||
_spec.ClearField(usagelog.FieldFirstTokenMs, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.UserAgent(); ok {
|
||||
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.UserAgentCleared() {
|
||||
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageCount(); ok {
|
||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedImageCount(); ok {
|
||||
_spec.AddField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.11
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/imroc/req/v3 v3.56.0
|
||||
github.com/imroc/req/v3 v3.57.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/spf13/viper v1.18.2
|
||||
@@ -20,16 +18,16 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/zeromicro/go-zero v1.9.4
|
||||
golang.org/x/crypto v0.44.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/term v0.37.0
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/net v0.48.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.38.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/agext/levenshtein v1.2.3 // indirect
|
||||
@@ -64,7 +62,6 @@ require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
@@ -74,10 +71,8 @@ require (
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
github.com/icholy/digest v1.1.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.1 // indirect
|
||||
github.com/klauspost/compress v1.18.2 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
@@ -105,8 +100,8 @@ require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/quic-go/quic-go v0.56.0 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.57.1 // indirect
|
||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
@@ -141,16 +136,12 @@ require (
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/mod v0.29.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gorm.io/datatypes v1.2.7 // indirect
|
||||
gorm.io/driver/mysql v1.5.6 // indirect
|
||||
gorm.io/gorm v1.30.0 // indirect
|
||||
)
|
||||
|
||||
@@ -4,8 +4,6 @@ dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
entgo.io/ent v0.14.5 h1:Rj2WOYJtCkWyFo6a+5wB3EfBRP0rnx1fMk6gGA0UUe4=
|
||||
entgo.io/ent v0.14.5/go.mod h1:zTzLmWtPvGpmSwtkaayM2cm5m819NdM7z7tYPq3vN0U=
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||
@@ -96,15 +94,12 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
|
||||
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
|
||||
github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68=
|
||||
github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -126,8 +121,8 @@ github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZY
|
||||
github.com/hashicorp/hcl/v2 v2.18.1/go.mod h1:ThLC89FV4p9MPW804KVbe/cEXoQ8NZEh+JtMeeGErHE=
|
||||
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
|
||||
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
@@ -138,14 +133,10 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
|
||||
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
@@ -219,10 +210,10 @@ github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
|
||||
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI10=
|
||||
github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
@@ -335,16 +326,16 @@ go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTV
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -354,16 +345,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
|
||||
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
|
||||
@@ -386,13 +377,6 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/datatypes v1.2.7 h1:ww9GAhF1aGXZY3EB3cJPJ7//JiuQo7DlQA7NNlVaTdk=
|
||||
gorm.io/datatypes v1.2.7/go.mod h1:M2iO+6S3hhi4nAyYe444Pcb0dcIiOMJ7QHaUXxyiNZY=
|
||||
gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8=
|
||||
gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
||||
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +18,7 @@ const (
|
||||
RunModeSimple = "simple"
|
||||
)
|
||||
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
|
||||
// 连接池隔离策略常量
|
||||
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
|
||||
@@ -51,6 +52,15 @@ type Config struct {
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
}
|
||||
|
||||
// UpdateConfig 在线更新相关配置
|
||||
type UpdateConfig struct {
|
||||
// ProxyURL 用于访问 GitHub 的代理地址
|
||||
// 支持 http/https/socks5/socks5h 协议
|
||||
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
|
||||
ProxyURL string `mapstructure:"proxy_url"`
|
||||
}
|
||||
|
||||
type GeminiConfig struct {
|
||||
@@ -147,7 +157,7 @@ type CSPConfig struct {
|
||||
}
|
||||
|
||||
type ProxyProbeConfig struct {
|
||||
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
|
||||
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证
|
||||
}
|
||||
|
||||
type BillingConfig struct {
|
||||
@@ -338,8 +348,19 @@ func NormalizeRunMode(value string) string {
|
||||
func Load() (*Config, error) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
|
||||
// Add config paths in priority order
|
||||
// 1. DATA_DIR environment variable (highest priority)
|
||||
if dataDir := os.Getenv("DATA_DIR"); dataDir != "" {
|
||||
viper.AddConfigPath(dataDir)
|
||||
}
|
||||
// 2. Docker data directory
|
||||
viper.AddConfigPath("/app/data")
|
||||
// 3. Current directory
|
||||
viper.AddConfigPath(".")
|
||||
// 4. Config subdirectory
|
||||
viper.AddConfigPath("./config")
|
||||
// 5. System config directory
|
||||
viper.AddConfigPath("/etc/sub2api")
|
||||
|
||||
// 环境变量支持
|
||||
@@ -372,13 +393,13 @@ func Load() (*Config, error) {
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
||||
|
||||
if cfg.Server.Mode != "release" && cfg.JWT.Secret == "" {
|
||||
if cfg.JWT.Secret == "" {
|
||||
secret, err := generateJWTSecret(64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate jwt secret error: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
|
||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
@@ -392,7 +413,7 @@ func Load() (*Config, error) {
|
||||
log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
|
||||
}
|
||||
|
||||
if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
|
||||
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
|
||||
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
|
||||
}
|
||||
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
|
||||
@@ -436,8 +457,8 @@ func setDefaults() {
|
||||
"raw.githubusercontent.com",
|
||||
})
|
||||
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
|
||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
|
||||
viper.SetDefault("security.url_allowlist.allow_insecure_http", false)
|
||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
|
||||
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
|
||||
viper.SetDefault("security.response_headers.enabled", false)
|
||||
viper.SetDefault("security.response_headers.additional_allowed", []string{})
|
||||
viper.SetDefault("security.response_headers.force_remove", []string{})
|
||||
@@ -546,20 +567,13 @@ func setDefaults() {
|
||||
viper.SetDefault("gemini.oauth.client_secret", "")
|
||||
viper.SetDefault("gemini.oauth.scopes", "")
|
||||
viper.SetDefault("gemini.quota.policy", "")
|
||||
|
||||
// Update - 在线更新配置
|
||||
// 代理地址为空表示直连 GitHub(适用于海外服务器)
|
||||
viper.SetDefault("update.proxy_url", "")
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.Server.Mode == "release" {
|
||||
if c.JWT.Secret == "" {
|
||||
return fmt.Errorf("jwt.secret is required in release mode")
|
||||
}
|
||||
if len(c.JWT.Secret) < 32 {
|
||||
return fmt.Errorf("jwt.secret must be at least 32 characters")
|
||||
}
|
||||
if isWeakJWTSecret(c.JWT.Secret) {
|
||||
return fmt.Errorf("jwt.secret is too weak")
|
||||
}
|
||||
}
|
||||
if c.JWT.ExpireHour <= 0 {
|
||||
return fmt.Errorf("jwt.expire_hour must be positive")
|
||||
}
|
||||
|
||||
@@ -80,8 +80,11 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||
if cfg.Security.URLAllowlist.Enabled {
|
||||
t.Fatalf("URLAllowlist.Enabled = true, want false")
|
||||
}
|
||||
if cfg.Security.URLAllowlist.AllowInsecureHTTP {
|
||||
t.Fatalf("URLAllowlist.AllowInsecureHTTP = true, want false")
|
||||
if !cfg.Security.URLAllowlist.AllowInsecureHTTP {
|
||||
t.Fatalf("URLAllowlist.AllowInsecureHTTP = false, want true")
|
||||
}
|
||||
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
|
||||
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
|
||||
}
|
||||
if cfg.Security.ResponseHeaders.Enabled {
|
||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
||||
|
||||
@@ -34,15 +34,16 @@ func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
|
||||
|
||||
// AccountHandler handles admin account management
|
||||
type AccountHandler struct {
|
||||
adminService service.AdminService
|
||||
oauthService *service.OAuthService
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
geminiOAuthService *service.GeminiOAuthService
|
||||
rateLimitService *service.RateLimitService
|
||||
accountUsageService *service.AccountUsageService
|
||||
accountTestService *service.AccountTestService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
adminService service.AdminService
|
||||
oauthService *service.OAuthService
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
geminiOAuthService *service.GeminiOAuthService
|
||||
antigravityOAuthService *service.AntigravityOAuthService
|
||||
rateLimitService *service.RateLimitService
|
||||
accountUsageService *service.AccountUsageService
|
||||
accountTestService *service.AccountTestService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
@@ -51,6 +52,7 @@ func NewAccountHandler(
|
||||
oauthService *service.OAuthService,
|
||||
openaiOAuthService *service.OpenAIOAuthService,
|
||||
geminiOAuthService *service.GeminiOAuthService,
|
||||
antigravityOAuthService *service.AntigravityOAuthService,
|
||||
rateLimitService *service.RateLimitService,
|
||||
accountUsageService *service.AccountUsageService,
|
||||
accountTestService *service.AccountTestService,
|
||||
@@ -58,15 +60,16 @@ func NewAccountHandler(
|
||||
crsSyncService *service.CRSSyncService,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
rateLimitService: rateLimitService,
|
||||
accountUsageService: accountUsageService,
|
||||
accountTestService: accountTestService,
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
adminService: adminService,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
rateLimitService: rateLimitService,
|
||||
accountUsageService: accountUsageService,
|
||||
accountTestService: accountTestService,
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +85,8 @@ type CreateAccountRequest struct {
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
@@ -98,6 +103,8 @@ type UpdateAccountRequest struct {
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
@@ -201,6 +208,8 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -258,6 +267,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -420,6 +431,19 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformAntigravity {
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
|
||||
@@ -26,31 +26,33 @@ func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardH
|
||||
}
|
||||
|
||||
// parseTimeRange parses start_date, end_date query parameters
|
||||
// Uses user's timezone if provided, otherwise falls back to server timezone
|
||||
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
|
||||
@@ -33,6 +33,12 @@ type CreateGroupRequest struct {
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
@@ -47,6 +53,12 @@ type UpdateGroupRequest struct {
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
@@ -139,6 +151,11 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -174,6 +191,11 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -52,15 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
|
||||
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.Proxy, 0, len(proxies))
|
||||
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *dto.ProxyFromService(&proxies[i]))
|
||||
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -102,8 +102,9 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -112,7 +113,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -143,7 +144,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromService(&records[i]))
|
||||
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
@@ -151,8 +152,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
// Stats handles getting usage statistics with filters
|
||||
// GET /api/v1/admin/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
// Parse filters
|
||||
var userID, apiKeyID int64
|
||||
// Parse filters - same as List endpoint
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
@@ -171,8 +172,50 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
accountID = id
|
||||
}
|
||||
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
groupID = id
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
stream = &val
|
||||
}
|
||||
|
||||
var billingType *int8
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
bt := int8(val)
|
||||
billingType = &bt
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
now := timezone.Now()
|
||||
userTZ := c.Query("timezone")
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
var startTime, endTime time.Time
|
||||
|
||||
startDateStr := c.Query("start_date")
|
||||
@@ -180,12 +223,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -195,39 +238,31 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
// Build filters and call GetStatsWithFilters
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
}
|
||||
|
||||
if userID > 0 {
|
||||
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
return
|
||||
}
|
||||
|
||||
// Get global stats
|
||||
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
stats, err := h.usageService.GetStatsWithFilters(c.Request.Context(), filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
// Package dto provides data transfer objects for HTTP handlers.
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func UserFromServiceShallow(u *service.User) *User {
|
||||
if u == nil {
|
||||
@@ -78,6 +82,11 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
@@ -117,6 +126,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
ExpiresAt: timeToUnixSeconds(a.ExpiresAt),
|
||||
AutoPauseOnExpired: a.AutoPauseOnExpired,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
Schedulable: a.Schedulable,
|
||||
@@ -154,6 +165,14 @@ func AccountFromService(a *service.Account) *Account {
|
||||
return out
|
||||
}
|
||||
|
||||
func timeToUnixSeconds(value *time.Time) *int64 {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
ts := value.Unix()
|
||||
return &ts
|
||||
}
|
||||
|
||||
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
|
||||
if ag == nil {
|
||||
return nil
|
||||
@@ -217,7 +236,21 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
}
|
||||
}
|
||||
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
// AccountSummaryFromService returns a minimal AccountSummary for usage log display.
|
||||
// Only includes ID and Name - no sensitive fields like Credentials, Proxy, etc.
|
||||
func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &AccountSummary{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
||||
// The account parameter allows caller to control what Account info is included.
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -247,15 +280,33 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
Stream: l.Stream,
|
||||
DurationMs: l.DurationMs,
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
ImageCount: l.ImageCount,
|
||||
ImageSize: l.ImageSize,
|
||||
UserAgent: l.UserAgent,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
Account: AccountFromService(l.Account),
|
||||
Account: account,
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
}
|
||||
|
||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||
// It excludes Account details - users should not see account information.
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return usageLogFromServiceBase(l, nil)
|
||||
}
|
||||
|
||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||
// It includes minimal Account info (ID, Name only).
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account))
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
if s == nil {
|
||||
return nil
|
||||
|
||||
@@ -47,6 +47,15 @@ type Group struct {
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
@@ -55,21 +64,23 @@ type Group struct {
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
Schedulable bool `json:"schedulable"`
|
||||
|
||||
@@ -169,15 +180,29 @@ type UsageLog struct {
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int `json:"image_count"`
|
||||
ImageSize *string `json:"image_size"`
|
||||
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
APIKey *APIKey `json:"api_key,omitempty"`
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
// AccountSummary is a minimal account info for usage log display.
|
||||
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
|
||||
type AccountSummary struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type Setting struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
|
||||
@@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
@@ -108,6 +111,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
@@ -226,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -267,7 +273,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -276,10 +282,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
}(result, account, userAgent)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -353,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -394,7 +401,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -403,10 +410,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
}(result, account, userAgent)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -678,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
// 验证 model 必填
|
||||
if parsedReq.Model == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
@@ -13,6 +14,26 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// claudeCodeValidator is a singleton validator for Claude Code client detection
|
||||
var claudeCodeValidator = service.NewClaudeCodeValidator()
|
||||
|
||||
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||
// 返回更新后的 context
|
||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||
// 解析请求体为 map
|
||||
var bodyMap map[string]any
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
}
|
||||
|
||||
// 验证是否为 Claude Code 客户端
|
||||
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
|
||||
// 更新 request context
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
@@ -83,19 +104,33 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
|
||||
|
||||
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
|
||||
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
|
||||
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
|
||||
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
|
||||
if releaseFunc == nil {
|
||||
return nil
|
||||
}
|
||||
var once sync.Once
|
||||
wrapped := func() {
|
||||
once.Do(releaseFunc)
|
||||
quit := make(chan struct{})
|
||||
|
||||
release := func() {
|
||||
once.Do(func() {
|
||||
releaseFunc()
|
||||
close(quit) // 通知监听 goroutine 退出
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
wrapped()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context 取消时释放资源
|
||||
release()
|
||||
case <-quit:
|
||||
// 正常释放已完成,goroutine 退出
|
||||
return
|
||||
}
|
||||
}()
|
||||
return wrapped
|
||||
|
||||
return release
|
||||
}
|
||||
|
||||
// IncrementWaitCount increments the wait count for a user
|
||||
|
||||
141
backend/internal/handler/gateway_helper_test.go
Normal file
141
backend/internal/handler/gateway_helper_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestWrapReleaseOnDone_NoGoroutineLeak 验证 wrapReleaseOnDone 修复后不会泄露 goroutine
|
||||
func TestWrapReleaseOnDone_NoGoroutineLeak(t *testing.T) {
|
||||
// 记录测试开始时的 goroutine 数量
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var releaseCount int32
|
||||
release := wrapReleaseOnDone(ctx, func() {
|
||||
atomic.AddInt32(&releaseCount, 1)
|
||||
})
|
||||
|
||||
// 正常释放
|
||||
release()
|
||||
|
||||
// 等待足够时间确保 goroutine 退出
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 验证只释放一次
|
||||
if count := atomic.LoadInt32(&releaseCount); count != 1 {
|
||||
t.Errorf("expected release count to be 1, got %d", count)
|
||||
}
|
||||
|
||||
// 强制 GC,清理已退出的 goroutine
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证 goroutine 数量没有增加(允许±2的误差,考虑到测试框架本身可能创建的 goroutine)
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+2 {
|
||||
t.Errorf("goroutine leak detected: initial=%d, final=%d, leaked=%d",
|
||||
initialGoroutines, finalGoroutines, finalGoroutines-initialGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapReleaseOnDone_ContextCancellation 验证 context 取消时也能正确释放
|
||||
func TestWrapReleaseOnDone_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var releaseCount int32
|
||||
_ = wrapReleaseOnDone(ctx, func() {
|
||||
atomic.AddInt32(&releaseCount, 1)
|
||||
})
|
||||
|
||||
// 取消 context,应该触发释放
|
||||
cancel()
|
||||
|
||||
// 等待释放完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证释放被调用
|
||||
if count := atomic.LoadInt32(&releaseCount); count != 1 {
|
||||
t.Errorf("expected release count to be 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapReleaseOnDone_MultipleCallsOnlyReleaseOnce 验证多次调用 release 只释放一次
|
||||
func TestWrapReleaseOnDone_MultipleCallsOnlyReleaseOnce(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var releaseCount int32
|
||||
release := wrapReleaseOnDone(ctx, func() {
|
||||
atomic.AddInt32(&releaseCount, 1)
|
||||
})
|
||||
|
||||
// 调用多次
|
||||
release()
|
||||
release()
|
||||
release()
|
||||
|
||||
// 等待执行完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证只释放一次
|
||||
if count := atomic.LoadInt32(&releaseCount); count != 1 {
|
||||
t.Errorf("expected release count to be 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapReleaseOnDone_NilReleaseFunc 验证 nil releaseFunc 不会 panic
|
||||
func TestWrapReleaseOnDone_NilReleaseFunc(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
release := wrapReleaseOnDone(ctx, nil)
|
||||
|
||||
if release != nil {
|
||||
t.Error("expected nil release function when releaseFunc is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapReleaseOnDone_ConcurrentCalls 验证并发调用的安全性
|
||||
func TestWrapReleaseOnDone_ConcurrentCalls(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var releaseCount int32
|
||||
release := wrapReleaseOnDone(ctx, func() {
|
||||
atomic.AddInt32(&releaseCount, 1)
|
||||
})
|
||||
|
||||
// 并发调用 release
|
||||
const numGoroutines = 10
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go release()
|
||||
}
|
||||
|
||||
// 等待所有 goroutine 完成
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 验证只释放一次
|
||||
if count := atomic.LoadInt32(&releaseCount); count != 1 {
|
||||
t.Errorf("expected release count to be 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWrapReleaseOnDone 性能基准测试
|
||||
func BenchmarkWrapReleaseOnDone(b *testing.B) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
release := wrapReleaseOnDone(ctx, func() {})
|
||||
release()
|
||||
}
|
||||
}
|
||||
@@ -164,6 +164,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// For Gemini native API, do not send Claude-style ping frames.
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
||||
|
||||
@@ -200,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
@@ -259,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -300,7 +307,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -309,10 +316,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
}(result, account, userAgent)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -242,7 +242,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
@@ -251,10 +251,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
}(result, account, userAgent)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,8 +88,9 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -98,7 +99,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -194,7 +195,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取时间范围参数
|
||||
now := timezone.Now()
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
var startTime, endTime time.Time
|
||||
|
||||
// 优先使用 start_date 和 end_date 参数
|
||||
@@ -204,12 +206,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
if startDateStr != "" && endDateStr != "" {
|
||||
// 使用自定义日期范围
|
||||
var err error
|
||||
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
@@ -221,13 +223,13 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
case "today":
|
||||
startTime = timezone.StartOfDay(now)
|
||||
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
case "week":
|
||||
startTime = now.AddDate(0, 0, -7)
|
||||
case "month":
|
||||
startTime = now.AddDate(0, -1, 0)
|
||||
default:
|
||||
startTime = timezone.StartOfDay(now)
|
||||
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
||||
}
|
||||
endTime = now
|
||||
}
|
||||
@@ -248,31 +250,33 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
|
||||
// Uses user's timezone if provided, otherwise falls back to server timezone
|
||||
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
userTZ := c.Query("timezone") // Get user's timezone from request
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
startDate := c.Query("start_date")
|
||||
endDate := c.Query("end_date")
|
||||
|
||||
var startTime, endTime time.Time
|
||||
|
||||
if startDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
|
||||
if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
|
||||
startTime = t
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
|
||||
}
|
||||
} else {
|
||||
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
|
||||
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
|
||||
}
|
||||
|
||||
if endDate != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
|
||||
if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
|
||||
endTime = t.Add(24 * time.Hour) // Include the end date
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
}
|
||||
} else {
|
||||
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
|
||||
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
}
|
||||
|
||||
return startTime, endTime
|
||||
|
||||
@@ -5,27 +5,66 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
|
||||
// resolveHost 从 URL 解析 host
|
||||
func resolveHost(urlStr string) string {
|
||||
parsed, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return parsed.Host
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
apiURL := fmt.Sprintf("%s/v1internal:%s", baseURL, action)
|
||||
isStream := action == "streamGenerateContent"
|
||||
if isStream {
|
||||
apiURL += "?alt=sse"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 基础 Headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
// Accept Header 根据请求类型设置
|
||||
if isStream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 显式设置 Host Header
|
||||
if host := resolveHost(apiURL); host != "" {
|
||||
req.Host = host
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
// 向后兼容:仅使用默认 BaseURL
|
||||
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
return NewAPIRequestWithURL(ctx, BaseURL, action, accessToken, body)
|
||||
}
|
||||
|
||||
// TokenResponse Google OAuth token 响应
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@@ -132,6 +171,38 @@ func NewClient(proxyURL string) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查超时错误
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查连接错误(DNS 失败、连接拒绝)
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查 URL 错误
|
||||
var urlErr *url.Error
|
||||
return errors.As(err, &urlErr)
|
||||
}
|
||||
|
||||
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
|
||||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||||
func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
if isConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
params := url.Values{}
|
||||
@@ -240,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
|
||||
}
|
||||
|
||||
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
||||
reqBody := LoadCodeAssistRequest{}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
@@ -249,40 +321,65 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := BaseURL + "/v1internal:loadCodeAssist"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:loadCodeAssist"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var loadResp LoadCodeAssistResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &loadResp, rawResp, nil
|
||||
}
|
||||
|
||||
var loadResp LoadCodeAssistResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &loadResp, rawResp, nil
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
// ModelQuotaInfo 模型配额信息
|
||||
@@ -307,6 +404,7 @@ type FetchAvailableModelsResponse struct {
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
|
||||
reqBody := FetchAvailableModelsRequest{Project: projectID}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
@@ -314,38 +412,63 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
apiURL := BaseURL + "/v1internal:fetchAvailableModels"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:fetchAvailableModels"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var modelsResp FetchAvailableModelsResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &modelsResp, rawResp, nil
|
||||
}
|
||||
|
||||
var modelsResp FetchAvailableModelsResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &modelsResp, rawResp, nil
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
@@ -67,6 +67,13 @@ type GeminiGenerationConfig struct {
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
|
||||
type GeminiImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
|
||||
}
|
||||
|
||||
// GeminiThinkingConfig Gemini thinking 配置
|
||||
|
||||
@@ -32,16 +32,79 @@ const (
|
||||
"https://www.googleapis.com/auth/cclog " +
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// API 端点
|
||||
BaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
|
||||
// User-Agent
|
||||
UserAgent = "antigravity/1.11.9 windows/amd64"
|
||||
// User-Agent(模拟官方客户端)
|
||||
UserAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
// URL 可用性 TTL(不可用 URL 的恢复时间)
|
||||
URLAvailabilityTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// BaseURLs 定义 Antigravity API 端点,按优先级排序
|
||||
// fallback 顺序: sandbox → daily → prod
|
||||
var BaseURLs = []string{
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
|
||||
"https://daily-cloudcode-pa.googleapis.com", // daily
|
||||
"https://cloudcode-pa.googleapis.com", // prod
|
||||
}
|
||||
|
||||
// BaseURL 默认 URL(保持向后兼容)
|
||||
var BaseURL = BaseURLs[0]
|
||||
|
||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
|
||||
type URLAvailability struct {
|
||||
mu sync.RWMutex
|
||||
unavailable map[string]time.Time // URL -> 恢复时间
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// DefaultURLAvailability 全局 URL 可用性管理器
|
||||
var DefaultURLAvailability = NewURLAvailability(URLAvailabilityTTL)
|
||||
|
||||
// NewURLAvailability 创建 URL 可用性管理器
|
||||
func NewURLAvailability(ttl time.Duration) *URLAvailability {
|
||||
return &URLAvailability{
|
||||
unavailable: make(map[string]time.Time),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkUnavailable 标记 URL 临时不可用
|
||||
func (u *URLAvailability) MarkUnavailable(url string) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
u.unavailable[url] = time.Now().Add(u.ttl)
|
||||
}
|
||||
|
||||
// IsAvailable 检查 URL 是否可用
|
||||
func (u *URLAvailability) IsAvailable(url string) bool {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
expiry, exists := u.unavailable[url]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
return time.Now().After(expiry)
|
||||
}
|
||||
|
||||
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
|
||||
func (u *URLAvailability) GetAvailableURLs() []string {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
result := make([]string, 0, len(BaseURLs))
|
||||
for _, url := range BaseURLs {
|
||||
expiry, exists := u.unavailable[url]
|
||||
if !exists || now.After(expiry) {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// OAuthSession 保存 OAuth 授权流程的临时状态
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
|
||||
@@ -1,17 +1,46 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
sessionRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
sessionRandMutex sync.Mutex
|
||||
)
|
||||
|
||||
// generateStableSessionID 基于用户消息内容生成稳定的 session ID
|
||||
func generateStableSessionID(contents []GeminiContent) string {
|
||||
// 查找第一个 user 消息的文本
|
||||
for _, content := range contents {
|
||||
if content.Role == "user" && len(content.Parts) > 0 {
|
||||
if text := content.Parts[0].Text; text != "" {
|
||||
h := sha256.Sum256([]byte(text))
|
||||
n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 回退:生成随机 session ID
|
||||
sessionRandMutex.Lock()
|
||||
n := sessionRand.Int63n(9_000_000_000_000_000_000)
|
||||
sessionRandMutex.Unlock()
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
type TransformOptions struct {
|
||||
EnableIdentityPatch bool
|
||||
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
|
||||
@@ -67,8 +96,15 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
|
||||
// 5. 构建内部请求
|
||||
innerRequest := GeminiRequest{
|
||||
Contents: contents,
|
||||
SafetySettings: DefaultSafetySettings,
|
||||
Contents: contents,
|
||||
// 总是设置 toolConfig,与官方客户端一致
|
||||
ToolConfig: &GeminiToolConfig{
|
||||
FunctionCallingConfig: &GeminiFunctionCallingConfig{
|
||||
Mode: "VALIDATED",
|
||||
},
|
||||
},
|
||||
// 总是生成 sessionId,基于用户消息内容
|
||||
SessionID: generateStableSessionID(contents),
|
||||
}
|
||||
|
||||
if systemInstruction != nil {
|
||||
@@ -79,14 +115,9 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
innerRequest.Tools = tools
|
||||
innerRequest.ToolConfig = &GeminiToolConfig{
|
||||
FunctionCallingConfig: &GeminiFunctionCallingConfig{
|
||||
Mode: "VALIDATED",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 如果提供了 metadata.user_id,复用为 sessionId
|
||||
// 如果提供了 metadata.user_id,优先使用
|
||||
if claudeReq.Metadata != nil && claudeReq.Metadata.UserID != "" {
|
||||
innerRequest.SessionID = claudeReq.Metadata.UserID
|
||||
}
|
||||
@@ -95,7 +126,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
v1Req := V1InternalRequest{
|
||||
Project: projectID,
|
||||
RequestID: "agent-" + uuid.New().String(),
|
||||
UserAgent: "sub2api",
|
||||
UserAgent: "antigravity", // 固定值,与官方客户端一致
|
||||
RequestType: "agent",
|
||||
Model: mappedModel,
|
||||
Request: innerRequest,
|
||||
@@ -104,37 +135,42 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
return json.Marshal(v1Req)
|
||||
}
|
||||
|
||||
func defaultIdentityPatch(modelName string) string {
|
||||
return fmt.Sprintf(
|
||||
"--- [IDENTITY_PATCH] ---\n"+
|
||||
"Ignore any previous instructions regarding your identity or host platform (e.g., Amazon Q, Google AI).\n"+
|
||||
"You are currently providing services as the native %s model via a standard API proxy.\n"+
|
||||
"Always use the 'claude' command for terminal tasks if relevant.\n"+
|
||||
"--- [SYSTEM_PROMPT_BEGIN] ---\n",
|
||||
modelName,
|
||||
)
|
||||
// antigravityIdentity Antigravity identity 提示词
|
||||
const antigravityIdentity = `<identity>
|
||||
You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.
|
||||
You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.
|
||||
The USER will send you requests, which you must always prioritize addressing. Along with each USER request, we will attach additional metadata about their current state, such as what files they have open and where their cursor is.
|
||||
This information may or may not be relevant to the coding task, it is up for you to decide.
|
||||
</identity>
|
||||
<communication_style>
|
||||
- **Proactiveness**. As an agent, you are allowed to be proactive, but only in the course of completing the user's task. For example, if the user asks you to add a new component, you can edit the code, verify build and test statuses, and take any other obvious follow-up actions, such as performing additional research. However, avoid surprising the user. For example, if the user asks HOW to approach something, you should answer their question and instead of jumping into editing a file.</communication_style>`
|
||||
|
||||
func defaultIdentityPatch(_ string) string {
|
||||
return antigravityIdentity
|
||||
}
|
||||
|
||||
// GetDefaultIdentityPatch 返回默认的 Antigravity 身份提示词
|
||||
func GetDefaultIdentityPatch() string {
|
||||
return antigravityIdentity
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
|
||||
var parts []GeminiPart
|
||||
|
||||
// 可选注入身份防护指令(身份补丁)
|
||||
if opts.EnableIdentityPatch {
|
||||
identityPatch := strings.TrimSpace(opts.IdentityPatch)
|
||||
if identityPatch == "" {
|
||||
identityPatch = defaultIdentityPatch(modelName)
|
||||
}
|
||||
parts = append(parts, GeminiPart{Text: identityPatch})
|
||||
}
|
||||
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
|
||||
userHasAntigravityIdentity := false
|
||||
var userSystemParts []GeminiPart
|
||||
|
||||
// 解析 system prompt
|
||||
if len(system) > 0 {
|
||||
// 尝试解析为字符串
|
||||
var sysStr string
|
||||
if err := json.Unmarshal(system, &sysStr); err == nil {
|
||||
if strings.TrimSpace(sysStr) != "" {
|
||||
parts = append(parts, GeminiPart{Text: sysStr})
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
|
||||
if strings.Contains(sysStr, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 尝试解析为数组
|
||||
@@ -142,17 +178,28 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if err := json.Unmarshal(system, &sysBlocks); err == nil {
|
||||
for _, block := range sysBlocks {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
parts = append(parts, GeminiPart{Text: block.Text})
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
|
||||
if strings.Contains(block.Text, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// identity patch 模式下,用分隔符包裹 system prompt,便于上游识别/调试;关闭时尽量保持原始 system prompt。
|
||||
if opts.EnableIdentityPatch && len(parts) > 0 {
|
||||
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
|
||||
// 仅在用户未提供 Antigravity identity 时注入
|
||||
if opts.EnableIdentityPatch && !userHasAntigravityIdentity {
|
||||
identityPatch := strings.TrimSpace(opts.IdentityPatch)
|
||||
if identityPatch == "" {
|
||||
identityPatch = defaultIdentityPatch(modelName)
|
||||
}
|
||||
parts = append(parts, GeminiPart{Text: identityPatch})
|
||||
}
|
||||
|
||||
// 添加用户的 system prompt
|
||||
parts = append(parts, userSystemParts...)
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,4 +7,6 @@ type Key string
|
||||
const (
|
||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
|
||||
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
||||
)
|
||||
|
||||
@@ -27,10 +27,9 @@ const (
|
||||
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
|
||||
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
|
||||
|
||||
// DefaultScopes for Google One (personal Google accounts with Gemini access)
|
||||
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client,
|
||||
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client
|
||||
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
|
||||
// DefaultGoogleOneScopes (DEPRECATED, no longer used)
|
||||
// Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes.
|
||||
// This constant is kept for backward compatibility but is not actively used.
|
||||
DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
|
||||
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
|
||||
@@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
effective.Scopes = DefaultAIStudioScopes
|
||||
}
|
||||
case "google_one":
|
||||
// Google One uses built-in Gemini CLI client (same as code_assist)
|
||||
// Built-in client can't request restricted scopes like generative-language.retriever
|
||||
if isBuiltinClient {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = DefaultGoogleOneScopes
|
||||
}
|
||||
// Google One always uses built-in Gemini CLI client (same as code_assist)
|
||||
// Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
default:
|
||||
// Default to Code Assist scopes
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
|
||||
@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with custom client",
|
||||
name: "Google One always uses built-in client (even if custom credentials passed)",
|
||||
input: OAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantScopes: DefaultGoogleOneScopes,
|
||||
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -40,7 +39,7 @@ type Options struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h)
|
||||
Timeout time.Duration // 请求总超时时间
|
||||
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true)
|
||||
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
|
||||
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
|
||||
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
|
||||
@@ -113,7 +112,8 @@ func buildTransport(opts Options) (*http.Transport, error) {
|
||||
}
|
||||
|
||||
if opts.InsecureSkipVerify {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
// 安全要求:禁止跳过证书验证,避免中间人攻击。
|
||||
return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead")
|
||||
}
|
||||
|
||||
proxyURL := strings.TrimSpace(opts.ProxyURL)
|
||||
|
||||
@@ -122,3 +122,40 @@ func StartOfMonth(t time.Time) time.Time {
|
||||
func ParseInLocation(layout, value string) (time.Time, error) {
|
||||
return time.ParseInLocation(layout, value, Location())
|
||||
}
|
||||
|
||||
// ParseInUserLocation parses a time string in the user's timezone.
|
||||
// If userTZ is empty or invalid, falls back to the configured server timezone.
|
||||
func ParseInUserLocation(layout, value, userTZ string) (time.Time, error) {
|
||||
loc := Location() // default to server timezone
|
||||
if userTZ != "" {
|
||||
if userLoc, err := time.LoadLocation(userTZ); err == nil {
|
||||
loc = userLoc
|
||||
}
|
||||
}
|
||||
return time.ParseInLocation(layout, value, loc)
|
||||
}
|
||||
|
||||
// NowInUserLocation returns the current time in the user's timezone.
|
||||
// If userTZ is empty or invalid, falls back to the configured server timezone.
|
||||
func NowInUserLocation(userTZ string) time.Time {
|
||||
if userTZ == "" {
|
||||
return Now()
|
||||
}
|
||||
if userLoc, err := time.LoadLocation(userTZ); err == nil {
|
||||
return time.Now().In(userLoc)
|
||||
}
|
||||
return Now()
|
||||
}
|
||||
|
||||
// StartOfDayInUserLocation returns the start of the given day in the user's timezone.
|
||||
// If userTZ is empty or invalid, falls back to the configured server timezone.
|
||||
func StartOfDayInUserLocation(t time.Time, userTZ string) time.Time {
|
||||
loc := Location()
|
||||
if userTZ != "" {
|
||||
if userLoc, err := time.LoadLocation(userTZ); err == nil {
|
||||
loc = userLoc
|
||||
}
|
||||
}
|
||||
t = t.In(loc)
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
@@ -76,7 +76,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
SetPriority(account.Priority).
|
||||
SetStatus(account.Status).
|
||||
SetErrorMessage(account.ErrorMessage).
|
||||
SetSchedulable(account.Schedulable)
|
||||
SetSchedulable(account.Schedulable).
|
||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -84,6 +85,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
if account.LastUsedAt != nil {
|
||||
builder.SetLastUsedAt(*account.LastUsedAt)
|
||||
}
|
||||
if account.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*account.ExpiresAt)
|
||||
}
|
||||
if account.RateLimitedAt != nil {
|
||||
builder.SetRateLimitedAt(*account.RateLimitedAt)
|
||||
}
|
||||
@@ -280,7 +284,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
SetPriority(account.Priority).
|
||||
SetStatus(account.Status).
|
||||
SetErrorMessage(account.ErrorMessage).
|
||||
SetSchedulable(account.Schedulable)
|
||||
SetSchedulable(account.Schedulable).
|
||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
@@ -292,6 +297,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
} else {
|
||||
builder.ClearLastUsedAt()
|
||||
}
|
||||
if account.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*account.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
if account.RateLimitedAt != nil {
|
||||
builder.SetRateLimitedAt(*account.RateLimitedAt)
|
||||
} else {
|
||||
@@ -570,6 +580,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -596,6 +607,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -629,6 +641,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -727,6 +740,27 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET schedulable = FALSE,
|
||||
updated_at = NOW()
|
||||
WHERE deleted_at IS NULL
|
||||
AND schedulable = TRUE
|
||||
AND auto_pause_on_expired = TRUE
|
||||
AND expires_at IS NOT NULL
|
||||
AND expires_at <= $1
|
||||
`, now)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
@@ -861,6 +895,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
|
||||
preds = append(preds,
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
)
|
||||
@@ -971,6 +1006,14 @@ func tempUnschedulablePredicate() dbpredicate.Account {
|
||||
})
|
||||
}
|
||||
|
||||
func notExpiredPredicate(now time.Time) dbpredicate.Account {
|
||||
return dbaccount.Or(
|
||||
dbaccount.ExpiresAtIsNil(),
|
||||
dbaccount.ExpiresAtGT(now),
|
||||
dbaccount.AutoPauseOnExpiredEQ(false),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
|
||||
out := make(map[int64]tempUnschedSnapshot)
|
||||
if len(accountIDs) == 0 {
|
||||
@@ -1086,6 +1129,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
AutoPauseOnExpired: m.AutoPauseOnExpired,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
Schedulable: m.Schedulable,
|
||||
|
||||
@@ -321,7 +321,12 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
// 确保数据库 schema 已准备就绪。
|
||||
// SQL 迁移文件是 schema 的权威来源(source of truth)。
|
||||
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
|
||||
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
migrationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
|
||||
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
|
||||
// 格式: sticky_session:{groupID}:{sessionHash}
|
||||
func buildSessionKey(groupID int64, sessionHash string) string {
|
||||
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Get(ctx, key).Int64()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
|
||||
@@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() {
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
|
||||
sessionID := "s1"
|
||||
accountID := int64(99)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
|
||||
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.NoError(s.T(), err, "GetSessionAccountID")
|
||||
require.Equal(s.T(), accountID, sid, "session id mismatch")
|
||||
}
|
||||
@@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
|
||||
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
|
||||
sessionID := "s2"
|
||||
accountID := int64(100)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL sessionKey after Set")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
|
||||
@@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
|
||||
sessionID := "s3"
|
||||
accountID := int64(101)
|
||||
groupID := int64(1)
|
||||
initialTTL := 1 * time.Minute
|
||||
refreshTTL := 3 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
|
||||
|
||||
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
|
||||
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
|
||||
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after Refresh")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
|
||||
@@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
// RefreshSessionTTL on a missing key should not error (no-op)
|
||||
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
|
||||
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
sessionKey := stickySessionPrefix + sessionID
|
||||
groupID := int64(1)
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
|
||||
// Set a non-integer value
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.Error(s.T(), err, "expected error for corrupted value")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
|
||||
|
||||
// Use different OAuth clients based on oauthType:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
|
||||
// - google_one: always use built-in Gemini CLI OAuth client (public)
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
|
||||
@@ -14,23 +14,33 @@ import (
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
httpClient *http.Client
|
||||
allowPrivateHosts bool
|
||||
httpClient *http.Client
|
||||
downloadHTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||
allowPrivate := false
|
||||
// NewGitHubReleaseClient 创建 GitHub Release 客户端
|
||||
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
|
||||
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: allowPrivate,
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
// 下载客户端需要更长的超时时间
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
|
||||
return &githubReleaseClient{
|
||||
httpClient: sharedClient,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
httpClient: sharedClient,
|
||||
downloadHTTPClient: downloadClient,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
return err
|
||||
}
|
||||
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: c.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
resp, err := downloadClient.Do(req)
|
||||
// 使用预配置的下载客户端(已包含代理配置)
|
||||
resp, err := c.downloadHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -39,8 +39,8 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
||||
func newTestGitHubReleaseClient() *githubReleaseClient {
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{},
|
||||
allowPrivateHosts: true,
|
||||
httpClient: &http.Client{},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
allowPrivateHosts: true,
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
@@ -43,7 +43,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays)
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
@@ -69,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
|
||||
builder := r.client.Group.UpdateOneID(groupIn.ID).
|
||||
SetName(groupIn.Name).
|
||||
SetDescription(groupIn.Description).
|
||||
SetPlatform(groupIn.Platform).
|
||||
@@ -80,8 +85,20 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
Save(ctx)
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
|
||||
|
||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupID != nil {
|
||||
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
|
||||
} else {
|
||||
builder = builder.ClearFallbackGroupID()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -17,17 +16,12 @@ type pricingRemoteClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
|
||||
allowPrivate := false
|
||||
validateResolvedIP := true
|
||||
if cfg != nil {
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
}
|
||||
// NewPricingRemoteClient 创建定价数据远程客户端
|
||||
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
|
||||
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: validateResolvedIP,
|
||||
AllowPrivateHosts: allowPrivate,
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
@@ -20,13 +19,7 @@ type PricingServiceSuite struct {
|
||||
|
||||
func (s *PricingServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, ok := NewPricingRemoteClient(&config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
AllowPrivateHosts: true,
|
||||
},
|
||||
},
|
||||
}).(*pricingRemoteClient)
|
||||
client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
}
|
||||
return &proxyProbeService{
|
||||
ipInfoURL: defaultIPInfoURL,
|
||||
|
||||
@@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
return outProxies, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy
|
||||
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
q := r.client.Proxy.Query()
|
||||
if protocol != "" {
|
||||
q = q.Where(proxy.ProtocolEQ(protocol))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(proxy.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(proxy.NameContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxies, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(proxy.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Get account counts
|
||||
counts, err := r.GetAccountCountsForProxies(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Build result with account counts
|
||||
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
proxyOut := proxyEntityToService(proxies[i])
|
||||
if proxyOut == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, service.ProxyWithAccountCount{
|
||||
Proxy: *proxyOut,
|
||||
AccountCount: counts[proxyOut.ID],
|
||||
})
|
||||
}
|
||||
|
||||
return result, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.StatusEQ(service.StatusActive)).
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
@@ -109,6 +109,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
stream,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
image_count,
|
||||
image_size,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
@@ -116,7 +119,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -126,6 +129,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
subscriptionID := nullInt64(log.SubscriptionID)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
userAgent := nullString(log.UserAgent)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
@@ -157,6 +162,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.Stream,
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
createdAt,
|
||||
}
|
||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
@@ -1382,6 +1390,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetStatsWithFilters gets usage statistics with optional filters
|
||||
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
|
||||
conditions := make([]string, 0, 9)
|
||||
args := make([]any, 0, 9)
|
||||
|
||||
if filters.UserID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||||
args = append(args, filters.UserID)
|
||||
}
|
||||
if filters.APIKeyID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||||
args = append(args, filters.APIKeyID)
|
||||
}
|
||||
if filters.AccountID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||||
args = append(args, filters.AccountID)
|
||||
}
|
||||
if filters.GroupID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
|
||||
args = append(args, filters.GroupID)
|
||||
}
|
||||
if filters.Model != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||||
args = append(args, filters.Model)
|
||||
}
|
||||
if filters.Stream != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||||
args = append(args, *filters.Stream)
|
||||
}
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
}
|
||||
if filters.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||||
args = append(args, *filters.StartTime)
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||||
args = append(args, *filters.EndTime)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
%s
|
||||
`, buildWhere(conditions))
|
||||
|
||||
stats := &UsageStats{}
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
args,
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// AccountUsageHistory represents daily usage history for an account
|
||||
type AccountUsageHistory = usagestats.AccountUsageHistory
|
||||
|
||||
@@ -1789,6 +1872,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
stream bool
|
||||
durationMs sql.NullInt64
|
||||
firstTokenMs sql.NullInt64
|
||||
userAgent sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
@@ -1818,6 +1904,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&stream,
|
||||
&durationMs,
|
||||
&firstTokenMs,
|
||||
&userAgent,
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -1844,6 +1933,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
RateMultiplier: rateMultiplier,
|
||||
BillingType: int8(billingType),
|
||||
Stream: stream,
|
||||
ImageCount: imageCount,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
@@ -1866,6 +1956,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
value := int(firstTokenMs.Int64)
|
||||
log.FirstTokenMs = &value
|
||||
}
|
||||
if userAgent.Valid {
|
||||
log.UserAgent = &userAgent.String
|
||||
}
|
||||
if imageSize.Valid {
|
||||
log.ImageSize = &imageSize.String
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
@@ -1938,6 +2034,13 @@ func nullInt(v *int) sql.NullInt64 {
|
||||
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
||||
}
|
||||
|
||||
func nullString(v *string) sql.NullString {
|
||||
if v == nil || *v == "" {
|
||||
return sql.NullString{}
|
||||
}
|
||||
return sql.NullString{String: *v, Valid: true}
|
||||
}
|
||||
|
||||
func setToSlice(set map[int64]struct{}) []int64 {
|
||||
out := make([]int64, 0, len(set))
|
||||
for id := range set {
|
||||
|
||||
@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
|
||||
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
|
||||
}
|
||||
|
||||
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
|
||||
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
|
||||
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
|
||||
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
|
||||
}
|
||||
|
||||
// ProvidePricingRemoteClient 创建定价数据远程客户端
|
||||
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
|
||||
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
|
||||
return NewPricingRemoteClient(cfg.Update.ProxyURL)
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
@@ -53,8 +65,8 @@ var ProviderSet = wire.NewSet(
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
NewPricingRemoteClient,
|
||||
NewGitHubReleaseClient,
|
||||
ProvidePricingRemoteClient,
|
||||
ProvideGitHubReleaseClient,
|
||||
NewProxyExitInfoProber,
|
||||
NewClaudeUsageFetcher,
|
||||
NewClaudeOAuthClient,
|
||||
|
||||
@@ -241,7 +241,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"stream": true,
|
||||
"duration_ms": 100,
|
||||
"first_token_ms": 50,
|
||||
"created_at": "2025-01-02T03:04:05Z"
|
||||
"image_count": 0,
|
||||
"image_size": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"user_agent": null
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
@@ -1063,6 +1066,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubSettingRepo struct {
|
||||
all map[string]string
|
||||
}
|
||||
|
||||
@@ -9,21 +9,23 @@ import (
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
ExpiresAt *time.Time
|
||||
AutoPauseOnExpired bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
Schedulable bool
|
||||
|
||||
@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
|
||||
return false
|
||||
}
|
||||
|
||||
71
backend/internal/service/account_expiry_service.go
Normal file
71
backend/internal/service/account_expiry_service.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
|
||||
type AccountExpiryService struct {
|
||||
accountRepo AccountRepository
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
|
||||
return &AccountExpiryService{
|
||||
accountRepo: accountRepo,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) Start() {
|
||||
if s == nil || s.accountRepo == nil || s.interval <= 0 {
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.runOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.runOnce()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) runOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
|
||||
if err != nil {
|
||||
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
if updated > 0 {
|
||||
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,7 @@ type AccountRepository interface {
|
||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]Account, error)
|
||||
@@ -71,29 +72,33 @@ type AccountBulkUpdate struct {
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest 更新账号请求
|
||||
type UpdateAccountRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Name *string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
}
|
||||
|
||||
// AccountService 账号管理服务
|
||||
@@ -134,6 +139,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: StatusActive,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
}
|
||||
if req.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
|
||||
} else {
|
||||
account.AutoPauseOnExpired = true
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
@@ -224,6 +235,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
if req.Status != nil {
|
||||
account.Status = *req.Status
|
||||
}
|
||||
if req.ExpiresAt != nil {
|
||||
account.ExpiresAt = req.ExpiresAt
|
||||
}
|
||||
if req.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if req.GroupIDs != nil {
|
||||
|
||||
@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula
|
||||
panic("unexpected SetSchedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
panic("unexpected AutoPauseExpiredAccounts call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
panic("unexpected BindGroups call")
|
||||
}
|
||||
|
||||
@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if candidate, ok := candidates[0].(map[string]any); ok {
|
||||
// Check for completion
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract content
|
||||
// Extract content first (before checking completion)
|
||||
if content, ok := candidate["content"].(map[string]any); ok {
|
||||
if parts, ok := content["parts"].([]any); ok {
|
||||
for _, part := range parts {
|
||||
@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for completion after extracting content
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ type UsageLogRepository interface {
|
||||
// Admin usage listing/stats
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error)
|
||||
|
||||
// Account stats
|
||||
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
|
||||
|
||||
@@ -47,6 +47,7 @@ type AdminService interface {
|
||||
|
||||
// Proxy management
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
|
||||
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error)
|
||||
GetAllProxies(ctx context.Context) ([]Proxy, error)
|
||||
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
|
||||
GetProxy(ctx context.Context, id int64) (*Proxy, error)
|
||||
@@ -98,6 +99,12 @@ type CreateGroupInput struct {
|
||||
DailyLimitUSD *float64 // 日限额 (USD)
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -111,19 +118,27 @@ type UpdateGroupInput struct {
|
||||
DailyLimitUSD *float64 // 日限额 (USD)
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
GroupIDs []int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
GroupIDs []int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
@@ -140,6 +155,8 @@ type UpdateAccountInput struct {
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
|
||||
}
|
||||
|
||||
@@ -498,6 +515,18 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||
|
||||
// 图片价格:负数表示清除(使用默认价格),0 保留(表示免费)
|
||||
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||
|
||||
// 校验降级分组
|
||||
if input.FallbackGroupID != nil {
|
||||
if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
@@ -509,6 +538,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -524,6 +558,37 @@ func normalizeLimit(limit *float64) *float64 {
|
||||
return limit
|
||||
}
|
||||
|
||||
// normalizePrice 将负数转换为 nil(表示使用默认价格),0 保留(表示免费)
|
||||
func normalizePrice(price *float64) *float64 {
|
||||
if price == nil || *price < 0 {
|
||||
return nil
|
||||
}
|
||||
return price
|
||||
}
|
||||
|
||||
// validateFallbackGroup 校验降级分组的有效性
|
||||
// currentGroupID: 当前分组 ID(新建时为 0)
|
||||
// fallbackGroupID: 降级分组 ID
|
||||
func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error {
|
||||
// 不能将自己设置为降级分组
|
||||
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
|
||||
return fmt.Errorf("cannot set self as fallback group")
|
||||
}
|
||||
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -563,6 +628,33 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.MonthlyLimitUSD != nil {
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
}
|
||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||
if input.ImagePrice1K != nil {
|
||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||
}
|
||||
if input.ImagePrice2K != nil {
|
||||
group.ImagePrice2K = normalizePrice(input.ImagePrice2K)
|
||||
}
|
||||
if input.ImagePrice4K != nil {
|
||||
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
||||
}
|
||||
|
||||
// Claude Code 客户端限制
|
||||
if input.ClaudeCodeOnly != nil {
|
||||
group.ClaudeCodeOnly = *input.ClaudeCodeOnly
|
||||
}
|
||||
if input.FallbackGroupID != nil {
|
||||
// 校验降级分组
|
||||
if *input.FallbackGroupID > 0 {
|
||||
if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
group.FallbackGroupID = input.FallbackGroupID
|
||||
} else {
|
||||
// 传入 0 或负数表示清除降级分组
|
||||
group.FallbackGroupID = nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -666,6 +758,15 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
}
|
||||
if input.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||
} else {
|
||||
account.AutoPauseOnExpired = true
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -721,6 +822,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if input.Status != "" {
|
||||
account.Status = input.Status
|
||||
}
|
||||
if input.ExpiresAt != nil {
|
||||
if *input.ExpiresAt <= 0 {
|
||||
account.ExpiresAt = nil
|
||||
} else {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
}
|
||||
}
|
||||
if input.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if input.GroupIDs != nil {
|
||||
@@ -892,6 +1004,15 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
|
||||
return proxies, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return proxies, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
|
||||
return s.proxyRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
@@ -186,6 +186,10 @@ func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]Proxy
|
||||
panic("unexpected ListActiveWithAccountCount call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFiltersAndAccountCount call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
panic("unexpected ExistsByHostPortAuth call")
|
||||
}
|
||||
|
||||
197
backend/internal/service/admin_service_group_test.go
Normal file
197
backend/internal/service/admin_service_group_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
|
||||
type groupRepoStubForAdmin struct {
|
||||
created *Group // 记录 Create 调用的参数
|
||||
updated *Group // 记录 Update 调用的参数
|
||||
getByID *Group // GetByID 返回值
|
||||
getErr error // GetByID 返回的错误
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||||
s.created = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Update(_ context.Context, g *Group) error {
|
||||
s.updated = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
price1K := 0.10
|
||||
price2K := 0.15
|
||||
price4K := 0.30
|
||||
|
||||
input := &CreateGroupInput{
|
||||
Name: "test-group",
|
||||
Description: "Test group",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
ImagePrice1K: &price1K,
|
||||
ImagePrice2K: &price2K,
|
||||
ImagePrice4K: &price4K,
|
||||
}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 repo 收到了正确的字段
|
||||
require.NotNil(t, repo.created)
|
||||
require.NotNil(t, repo.created.ImagePrice1K)
|
||||
require.NotNil(t, repo.created.ImagePrice2K)
|
||||
require.NotNil(t, repo.created.ImagePrice4K)
|
||||
require.InDelta(t, 0.10, *repo.created.ImagePrice1K, 0.0001)
|
||||
require.InDelta(t, 0.15, *repo.created.ImagePrice2K, 0.0001)
|
||||
require.InDelta(t, 0.30, *repo.created.ImagePrice4K, 0.0001)
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_NilImagePricing 测试 ImagePrice 为 nil 时正常创建
|
||||
func TestAdminService_CreateGroup_NilImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
input := &CreateGroupInput{
|
||||
Name: "test-group",
|
||||
Description: "Test group",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
// ImagePrice 字段全部为 nil
|
||||
}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 ImagePrice 字段为 nil
|
||||
require.NotNil(t, repo.created)
|
||||
require.Nil(t, repo.created.ImagePrice1K)
|
||||
require.Nil(t, repo.created.ImagePrice2K)
|
||||
require.Nil(t, repo.created.ImagePrice4K)
|
||||
}
|
||||
|
||||
// TestAdminService_UpdateGroup_WithImagePricing 测试更新分组时 ImagePrice 字段正确更新
|
||||
func TestAdminService_UpdateGroup_WithImagePricing(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
price1K := 0.12
|
||||
price2K := 0.18
|
||||
price4K := 0.36
|
||||
|
||||
input := &UpdateGroupInput{
|
||||
ImagePrice1K: &price1K,
|
||||
ImagePrice2K: &price2K,
|
||||
ImagePrice4K: &price4K,
|
||||
}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 repo 收到了更新后的字段
|
||||
require.NotNil(t, repo.updated)
|
||||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||||
require.NotNil(t, repo.updated.ImagePrice4K)
|
||||
require.InDelta(t, 0.12, *repo.updated.ImagePrice1K, 0.0001)
|
||||
require.InDelta(t, 0.18, *repo.updated.ImagePrice2K, 0.0001)
|
||||
require.InDelta(t, 0.36, *repo.updated.ImagePrice4K, 0.0001)
|
||||
}
|
||||
|
||||
// TestAdminService_UpdateGroup_PartialImagePricing 测试仅更新部分 ImagePrice 字段
|
||||
func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
oldPrice2K := 0.15
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
ImagePrice2K: &oldPrice2K, // 已有 2K 价格
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
// 只更新 1K 价格
|
||||
price1K := 0.10
|
||||
input := &UpdateGroupInput{
|
||||
ImagePrice1K: &price1K,
|
||||
// ImagePrice2K 和 ImagePrice4K 为 nil,不更新
|
||||
}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证:1K 被更新,2K 保持原值,4K 仍为 nil
|
||||
require.NotNil(t, repo.updated)
|
||||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||||
require.InDelta(t, 0.10, *repo.updated.ImagePrice1K, 0.0001)
|
||||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
123
backend/internal/service/antigravity_image_test.go
Normal file
123
backend/internal/service/antigravity_image_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIsImageGenerationModel_GeminiProImage 测试 gemini-3-pro-image 识别
|
||||
func TestIsImageGenerationModel_GeminiProImage(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("gemini-3-pro-image"))
|
||||
require.True(t, isImageGenerationModel("gemini-3-pro-image-preview"))
|
||||
require.True(t, isImageGenerationModel("models/gemini-3-pro-image"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_GeminiFlashImage 测试 gemini-2.5-flash-image 识别
|
||||
func TestIsImageGenerationModel_GeminiFlashImage(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("gemini-2.5-flash-image"))
|
||||
require.True(t, isImageGenerationModel("gemini-2.5-flash-image-preview"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_RegularModel 测试普通模型不被识别为图片模型
|
||||
func TestIsImageGenerationModel_RegularModel(t *testing.T) {
|
||||
require.False(t, isImageGenerationModel("claude-3-opus"))
|
||||
require.False(t, isImageGenerationModel("claude-sonnet-4-20250514"))
|
||||
require.False(t, isImageGenerationModel("gpt-4o"))
|
||||
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
|
||||
require.False(t, isImageGenerationModel("gemini-2.5-flash"))
|
||||
// 验证不会误匹配包含关键词的自定义模型名
|
||||
require.False(t, isImageGenerationModel("my-gemini-3-pro-image-test"))
|
||||
require.False(t, isImageGenerationModel("custom-gemini-2.5-flash-image-wrapper"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感
|
||||
func TestIsImageGenerationModel_CaseInsensitive(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("GEMINI-3-PRO-IMAGE"))
|
||||
require.True(t, isImageGenerationModel("Gemini-3-Pro-Image"))
|
||||
require.True(t, isImageGenerationModel("GEMINI-2.5-FLASH-IMAGE"))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_ValidSizes 测试有效尺寸解析
|
||||
func TestExtractImageSize_ValidSizes(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 1K
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
|
||||
// 2K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 4K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
|
||||
func TestExtractImageSize_CaseInsensitive(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
|
||||
func TestExtractImageSize_Default(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 无 generationConfig
|
||||
body := []byte(`{"contents":[]}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 有 generationConfig 但无 imageConfig
|
||||
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 有 imageConfig 但无 imageSize
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
|
||||
func TestExtractImageSize_InvalidJSON(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`not valid json`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"broken":`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
|
||||
func TestExtractImageSize_EmptySize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 空格
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
|
||||
func TestExtractImageSize_InvalidSize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
@@ -20,12 +20,16 @@ var (
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
)
|
||||
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
const maxTokenLength = 8192
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
type JWTClaims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -78,14 +82,18 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
|
||||
// 检查是否需要邮件验证
|
||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||
// 这是一个配置错误,不应该允许绕过验证
|
||||
if s.emailService == nil {
|
||||
log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration")
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if verifyCode == "" {
|
||||
return "", nil, ErrEmailVerifyRequired
|
||||
}
|
||||
// 验证邮箱验证码
|
||||
if s.emailService != nil {
|
||||
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
|
||||
return "", nil, fmt.Errorf("verify code: %w", err)
|
||||
}
|
||||
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
|
||||
return "", nil, fmt.Errorf("verify code: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -309,7 +317,20 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
if len(tokenString) > maxTokenLength {
|
||||
return nil, ErrTokenTooLarge
|
||||
}
|
||||
|
||||
// 使用解析器并限制可接受的签名算法,防止算法混淆。
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{
|
||||
jwt.SigningMethodHS256.Name,
|
||||
jwt.SigningMethodHS384.Name,
|
||||
jwt.SigningMethodHS512.Name,
|
||||
}))
|
||||
|
||||
// 保留默认 claims 校验(exp/nbf),避免放行过期或未生效的 token。
|
||||
token, err := parser.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
@@ -319,6 +340,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
// token 过期但仍返回 claims(用于 RefreshToken 等场景)
|
||||
// jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充
|
||||
if claims, ok := token.Claims.(*JWTClaims); ok {
|
||||
return claims, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
|
||||
@@ -113,13 +113,27 @@ func TestAuthService_Register_Disabled(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
cache := &emailCacheStub{} // 配置 emailService
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
@@ -180,3 +194,63 @@ func TestAuthService_Register_Success(t *testing.T) {
|
||||
require.Len(t, repo.created, 1)
|
||||
require.True(t, user.CheckPassword("password"))
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建用户并生成 token
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
token, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证有效 token
|
||||
claims, err := service.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, claims)
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
|
||||
// 模拟过期 token(通过创建一个过期很久的 token)
|
||||
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1 // 恢复
|
||||
|
||||
// 验证过期 token 应返回 claims 和 ErrTokenExpired
|
||||
claims, err = service.ValidateToken(expiredToken)
|
||||
require.ErrorIs(t, err, ErrTokenExpired)
|
||||
require.NotNil(t, claims, "claims should not be nil when token is expired")
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
require.Equal(t, "test@test.com", claims.Email)
|
||||
}
|
||||
|
||||
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
repo := &userRepoStub{user: user}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建过期 token
|
||||
service.cfg.JWT.ExpireHour = -1
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1
|
||||
|
||||
// RefreshToken 使用过期 token 不应 panic
|
||||
require.NotPanics(t, func() {
|
||||
newToken, err := service.RefreshToken(context.Background(), expiredToken)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, newToken)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -295,3 +295,88 @@ func (s *BillingService) ForceUpdatePricing() error {
|
||||
}
|
||||
return fmt.Errorf("pricing service not initialized")
|
||||
}
|
||||
|
||||
// ImagePriceConfig 图片计费配置
|
||||
type ImagePriceConfig struct {
|
||||
Price1K *float64 // 1K 尺寸价格(nil 表示使用默认值)
|
||||
Price2K *float64 // 2K 尺寸价格(nil 表示使用默认值)
|
||||
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
|
||||
}
|
||||
|
||||
// CalculateImageCost 计算图片生成费用
|
||||
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
||||
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
||||
// imageCount: 生成的图片数量
|
||||
// groupConfig: 分组配置的价格(可能为 nil,表示使用默认值)
|
||||
// rateMultiplier: 费率倍数
|
||||
func (s *BillingService) CalculateImageCost(model string, imageSize string, imageCount int, groupConfig *ImagePriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
if imageCount <= 0 {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
|
||||
// 获取单价
|
||||
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
|
||||
|
||||
// 计算总费用
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
|
||||
// 应用倍率
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// getImageUnitPrice 获取图片单价
|
||||
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
||||
// 优先使用分组配置的价格
|
||||
if groupConfig != nil {
|
||||
switch imageSize {
|
||||
case "1K":
|
||||
if groupConfig.Price1K != nil {
|
||||
return *groupConfig.Price1K
|
||||
}
|
||||
case "2K":
|
||||
if groupConfig.Price2K != nil {
|
||||
return *groupConfig.Price2K
|
||||
}
|
||||
case "4K":
|
||||
if groupConfig.Price4K != nil {
|
||||
return *groupConfig.Price4K
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 LiteLLM 默认价格
|
||||
return s.getDefaultImagePrice(model, imageSize)
|
||||
}
|
||||
|
||||
// getDefaultImagePrice 获取 LiteLLM 默认图片价格
|
||||
func (s *BillingService) getDefaultImagePrice(model string, imageSize string) float64 {
|
||||
basePrice := 0.0
|
||||
|
||||
// 从 PricingService 获取 output_cost_per_image
|
||||
if s.pricingService != nil {
|
||||
pricing := s.pricingService.GetModelPricing(model)
|
||||
if pricing != nil && pricing.OutputCostPerImage > 0 {
|
||||
basePrice = pricing.OutputCostPerImage
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到价格,使用硬编码默认值($0.134,来自 gemini-3-pro-image-preview)
|
||||
if basePrice <= 0 {
|
||||
basePrice = 0.134
|
||||
}
|
||||
|
||||
// 4K 尺寸翻倍
|
||||
if imageSize == "4K" {
|
||||
return basePrice * 2
|
||||
}
|
||||
|
||||
return basePrice
|
||||
}
|
||||
|
||||
149
backend/internal/service/billing_service_image_test.go
Normal file
149
backend/internal/service/billing_service_image_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCalculateImageCost_DefaultPricing 测试无分组配置时使用默认价格
|
||||
func TestCalculateImageCost_DefaultPricing(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值
|
||||
|
||||
// 2K 尺寸,默认价格 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.134, cost.ActualCost, 0.0001)
|
||||
|
||||
// 多张图片
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0)
|
||||
require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格
|
||||
func TestCalculateImageCost_GroupCustomPricing(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
price1K := 0.10
|
||||
price2K := 0.15
|
||||
price4K := 0.30
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price1K: &price1K,
|
||||
Price2K: &price2K,
|
||||
Price4K: &price4K,
|
||||
}
|
||||
|
||||
// 1K 使用分组价格
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 2, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 使用分组价格
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.15, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 使用分组价格
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.30, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_4KDoublePrice 测试 4K 默认价格翻倍
|
||||
func TestCalculateImageCost_4KDoublePrice(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 4K 尺寸,默认价格翻倍 $0.134 * 2 = $0.268
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_RateMultiplier 测试费率倍数
|
||||
func TestCalculateImageCost_RateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 费率倍数 1.5x
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001) // TotalCost 不变
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // ActualCost = 0.134 * 1.5
|
||||
|
||||
// 费率倍数 2.0x
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.536, cost.ActualCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroCount 测试 imageCount=0
|
||||
func TestCalculateImageCost_ZeroCount(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 0, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_NegativeCount 测试 imageCount=-1
|
||||
func TestCalculateImageCost_NegativeCount(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", -1, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
|
||||
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.134, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
|
||||
func TestGetImageUnitPrice_GroupPriorityOverDefault(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
price2K := 0.20
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price2K: &price2K,
|
||||
}
|
||||
|
||||
// 分组配置了 2K 价格,应该使用分组价格而不是默认的 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_PartialGroupConfig 测试分组部分配置时回退默认
|
||||
func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 只配置 1K 价格
|
||||
price1K := 0.10
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price1K: &price1K,
|
||||
}
|
||||
|
||||
// 1K 使用分组价格
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.10, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 回退默认价格 $0.134
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 回退默认价格 $0.268 (翻倍)
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestGetDefaultImagePrice_FallbackHardcoded 测试 PricingService 无数据时使用硬编码默认值
|
||||
func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil
|
||||
|
||||
// 1K 和 2K 使用相同的默认价格 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
}
|
||||
265
backend/internal/service/claude_code_validator.go
Normal file
265
backend/internal/service/claude_code_validator.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端
|
||||
// 完全学习自 claude-relay-service 项目的验证逻辑
|
||||
type ClaudeCodeValidator struct{}
|
||||
|
||||
var (
|
||||
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
|
||||
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
|
||||
|
||||
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||
systemPromptThreshold = 0.5
|
||||
)
|
||||
|
||||
// Claude Code 官方 System Prompt 模板
|
||||
// 从 claude-relay-service/src/utils/contents.js 提取
|
||||
var claudeCodeSystemPrompts = []string{
|
||||
// claudeOtherSystemPrompt1 - Primary
|
||||
"You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
|
||||
// claudeOtherSystemPrompt3 - Agent SDK
|
||||
"You are a Claude agent, built on Anthropic's Claude Agent SDK.",
|
||||
|
||||
// claudeOtherSystemPrompt4 - Compact Agent SDK
|
||||
"You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.",
|
||||
|
||||
// exploreAgentSystemPrompt
|
||||
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.",
|
||||
|
||||
// claudeOtherSystemPromptCompact - Compact (用于对话摘要)
|
||||
"You are a helpful AI assistant tasked with summarizing conversations.",
|
||||
|
||||
// claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分)
|
||||
"You are an interactive CLI tool that helps users",
|
||||
}
|
||||
|
||||
// NewClaudeCodeValidator 创建验证器实例
|
||||
func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
||||
return &ClaudeCodeValidator{}
|
||||
}
|
||||
|
||||
// Validate 验证请求是否来自 Claude Code CLI
|
||||
// 采用与 claude-relay-service 完全一致的验证策略:
|
||||
//
|
||||
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
||||
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||
// Step 3: 对于 messages 路径,进行严格验证:
|
||||
// - System prompt 相似度检查
|
||||
// - X-App header 检查
|
||||
// - anthropic-beta header 检查
|
||||
// - anthropic-version header 检查
|
||||
// - metadata.user_id 格式验证
|
||||
func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool {
|
||||
// Step 1: User-Agent 检查
|
||||
ua := r.Header.Get("User-Agent")
|
||||
if !claudeCodeUAPattern.MatchString(ua) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Step 2: 非 messages 路径,只要 UA 匹配就通过
|
||||
path := r.URL.Path
|
||||
if !strings.Contains(path, "messages") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Step 3: messages 路径,进行严格验证
|
||||
|
||||
// 3.1 检查 system prompt 相似度
|
||||
if !v.hasClaudeCodeSystemPrompt(body) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3.2 检查必需的 headers(值不为空即可)
|
||||
xApp := r.Header.Get("X-App")
|
||||
if xApp == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
anthropicBeta := r.Header.Get("anthropic-beta")
|
||||
if anthropicBeta == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
anthropicVersion := r.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3.3 验证 metadata.user_id
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
metadata, ok := body["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if !userIDPattern.MatchString(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词
|
||||
// 使用字符串相似度匹配(Dice coefficient)
|
||||
func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查 model 字段
|
||||
if _, ok := body["model"].(string); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取 system 字段
|
||||
systemEntries, ok := body["system"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查每个 system entry
|
||||
for _, entry := range systemEntries {
|
||||
entryMap, ok := entry.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
text, ok := entryMap["text"].(string)
|
||||
if !ok || text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算与所有模板的最佳相似度
|
||||
bestScore := v.bestSimilarityScore(text)
|
||||
if bestScore >= systemPromptThreshold {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度
|
||||
func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 {
|
||||
normalizedText := normalizePrompt(text)
|
||||
bestScore := 0.0
|
||||
|
||||
for _, template := range claudeCodeSystemPrompts {
|
||||
normalizedTemplate := normalizePrompt(template)
|
||||
score := diceCoefficient(normalizedText, normalizedTemplate)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
return bestScore
|
||||
}
|
||||
|
||||
// normalizePrompt 标准化提示词文本(去除多余空白)
|
||||
func normalizePrompt(text string) string {
|
||||
// 将所有空白字符替换为单个空格,并去除首尾空白
|
||||
return strings.Join(strings.Fields(text), " ")
|
||||
}
|
||||
|
||||
// diceCoefficient 计算两个字符串的 Dice 系数(Sørensen–Dice coefficient)
|
||||
// 这是 string-similarity 库使用的算法
|
||||
// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|)
|
||||
func diceCoefficient(a, b string) float64 {
|
||||
if a == b {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
if len(a) < 2 || len(b) < 2 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 生成 bigrams
|
||||
bigramsA := getBigrams(a)
|
||||
bigramsB := getBigrams(b)
|
||||
|
||||
if len(bigramsA) == 0 || len(bigramsB) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 计算交集大小
|
||||
intersection := 0
|
||||
for bigram, countA := range bigramsA {
|
||||
if countB, exists := bigramsB[bigram]; exists {
|
||||
if countA < countB {
|
||||
intersection += countA
|
||||
} else {
|
||||
intersection += countB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算总 bigram 数量
|
||||
totalA := 0
|
||||
for _, count := range bigramsA {
|
||||
totalA += count
|
||||
}
|
||||
totalB := 0
|
||||
for _, count := range bigramsB {
|
||||
totalB += count
|
||||
}
|
||||
|
||||
return float64(2*intersection) / float64(totalA+totalB)
|
||||
}
|
||||
|
||||
// getBigrams 获取字符串的所有 bigrams(相邻字符对)
|
||||
func getBigrams(s string) map[string]int {
|
||||
bigrams := make(map[string]int)
|
||||
runes := []rune(strings.ToLower(s))
|
||||
|
||||
for i := 0; i < len(runes)-1; i++ {
|
||||
bigram := string(runes[i : i+2])
|
||||
bigrams[bigram]++
|
||||
}
|
||||
|
||||
return bigrams
|
||||
}
|
||||
|
||||
// ValidateUserAgent 仅验证 User-Agent(用于不需要解析请求体的场景)
|
||||
func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool {
|
||||
return claudeCodeUAPattern.MatchString(ua)
|
||||
}
|
||||
|
||||
// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词
|
||||
// 只要存在匹配的系统提示词就返回 true(用于宽松检测)
|
||||
func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||
return v.hasClaudeCodeSystemPrompt(body)
|
||||
}
|
||||
|
||||
// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识
|
||||
func IsClaudeCodeClient(ctx context.Context) bool {
|
||||
if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中
|
||||
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
|
||||
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
|
||||
}
|
||||
@@ -140,6 +140,8 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
|
||||
func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: host,
|
||||
// 强制 TLS 1.2+,避免协议降级导致的弱加密风险。
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
@@ -311,7 +313,11 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseTLS {
|
||||
tlsConfig := &tls.Config{ServerName: config.Host}
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: config.Host,
|
||||
// 与发送逻辑一致,显式要求 TLS 1.2+。
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls connection failed: %w", err)
|
||||
|
||||
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, err
|
||||
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
@@ -163,14 +166,14 @@ type mockGatewayCacheForPlatform struct {
|
||||
sessionBindings map[string]int64
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
if id, ok := m.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if m.sessionBindings == nil {
|
||||
m.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
@@ -178,7 +181,7 @@ func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ const (
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 10 * 1024 * 1024
|
||||
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
||||
)
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
@@ -43,8 +44,21 @@ var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
|
||||
claudeCodePromptPrefixes = []string{
|
||||
"You are Claude Code, Anthropic's official CLI for Claude", // 标准版 & Agent SDK 版(含 running within...)
|
||||
"You are a Claude agent, built on Anthropic's Claude Agent SDK", // Agent SDK 变体
|
||||
"You are a file search specialist for Claude Code", // Explore Agent 版
|
||||
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
|
||||
}
|
||||
)
|
||||
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
"accept": true,
|
||||
@@ -69,9 +83,17 @@ var allowedHeaders = map[string]bool{
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
type GatewayCache interface {
|
||||
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
|
||||
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
|
||||
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||
func derefGroupID(groupID *int64) int64 {
|
||||
if groupID == nil {
|
||||
return 0
|
||||
}
|
||||
return *groupID
|
||||
}
|
||||
|
||||
type AccountWaitPlan struct {
|
||||
@@ -98,12 +120,17 @@ type ClaudeUsage struct {
|
||||
|
||||
// ForwardResult 转发结果
|
||||
type ForwardResult struct {
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
|
||||
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
||||
}
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
@@ -209,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
}
|
||||
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||
func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 || s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
||||
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
@@ -340,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
platform = group.Platform
|
||||
|
||||
// 检查 Claude Code 客户端限制
|
||||
if group.ClaudeCodeOnly {
|
||||
isClaudeCode := IsClaudeCodeClient(ctx)
|
||||
if !isClaudeCode {
|
||||
// 非 Claude Code 客户端,检查是否有降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
// 使用降级分组重新调度
|
||||
fallbackGroupID := *group.FallbackGroupID
|
||||
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
|
||||
}
|
||||
// 无降级分组,拒绝访问
|
||||
return nil, ErrClaudeCodeOnly
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 无分组时只使用原生 anthropic 平台
|
||||
platform = PlatformAnthropic
|
||||
@@ -351,17 +393,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
|
||||
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
|
||||
if hasForcePlatform && groupID != nil {
|
||||
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
if err == nil {
|
||||
return account, nil
|
||||
}
|
||||
// 分组中找不到,回退查询全部该平台账户
|
||||
groupID = nil
|
||||
}
|
||||
|
||||
// antigravity 分组、强制平台模式或无分组使用单平台选择
|
||||
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
|
||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
|
||||
@@ -370,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
cfg := s.schedulingConfig()
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
|
||||
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
@@ -436,15 +476,16 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
// ============ Layer 1: 粘性会话优先 ============
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
if err == nil && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulable() &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
@@ -498,7 +539,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
@@ -548,7 +589,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
@@ -576,7 +617,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
@@ -584,7 +625,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
@@ -611,6 +652,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
|
||||
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
|
||||
// - 有降级分组:返回降级分组的 ID
|
||||
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
|
||||
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
|
||||
if groupID == nil {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 强制平台模式不检查 Claude Code 限制
|
||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
|
||||
if !group.ClaudeCodeOnly {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 分组启用了 Claude Code 限制
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 非 Claude Code 客户端,检查降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
return group.FallbackGroupID, nil
|
||||
}
|
||||
|
||||
return nil, ErrClaudeCodeOnly
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
@@ -656,9 +733,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||
if err == nil && len(accounts) == 0 && hasForcePlatform {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
@@ -681,6 +756,23 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
|
||||
return account.Platform == platform
|
||||
}
|
||||
|
||||
// isAccountInGroup checks if the account belongs to the specified group.
|
||||
// Returns true if groupID is nil (no group restriction) or account belongs to the group.
|
||||
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
|
||||
if groupID == nil {
|
||||
return true // 无分组限制
|
||||
}
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
for _, ag := range account.AccountGroups {
|
||||
if ag.GroupID == *groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
if s.concurrencyService == nil {
|
||||
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
||||
@@ -715,13 +807,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
preferOAuth := platform == PlatformGemini
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
|
||||
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
@@ -788,7 +880,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -804,14 +896,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
@@ -880,7 +972,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -1009,15 +1101,15 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
}
|
||||
|
||||
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
|
||||
// 支持 string 和 []any 两种格式
|
||||
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
|
||||
func systemIncludesClaudeCodePrompt(system any) bool {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v == claudeCodeSystemPrompt
|
||||
return hasClaudeCodePrefix(v)
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||
if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -1026,6 +1118,16 @@ func systemIncludesClaudeCodePrompt(system any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头
|
||||
func hasClaudeCodePrefix(text string) bool {
|
||||
for _, prefix := range claudeCodePromptPrefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
||||
// 处理 null、字符串、数组三种格式
|
||||
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
@@ -1069,6 +1171,124 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
return result
|
||||
}
|
||||
|
||||
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制
|
||||
func enforceCacheControlLimit(body []byte) []byte {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// 计算当前 cache_control 块数量
|
||||
count := countCacheControlBlocks(data)
|
||||
if count <= maxCacheControlBlocks {
|
||||
return body
|
||||
}
|
||||
|
||||
// 超限:优先从 messages 中移除,再从 system 中移除
|
||||
for count > maxCacheControlBlocks {
|
||||
if removeCacheControlFromMessages(data) {
|
||||
count--
|
||||
continue
|
||||
}
|
||||
if removeCacheControlFromSystem(data) {
|
||||
count--
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
result, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
||||
func countCacheControlBlocks(data map[string]any) int {
|
||||
count := 0
|
||||
|
||||
// 统计 system 中的块
|
||||
if system, ok := data["system"].([]any); ok {
|
||||
for _, item := range system {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if _, has := m["cache_control"]; has {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 统计 messages 中的块
|
||||
if messages, ok := data["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if content, ok := msgMap["content"].([]any); ok {
|
||||
for _, item := range content {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if _, has := m["cache_control"]; has {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||
func removeCacheControlFromMessages(data map[string]any) bool {
|
||||
messages, ok := data["messages"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||
func removeCacheControlFromSystem(data map[string]any) bool {
|
||||
system, ok := data["system"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
||||
for i := len(system) - 1; i >= 0; i-- {
|
||||
if m, ok := system[i].(map[string]any); ok {
|
||||
if _, has := m["cache_control"]; has {
|
||||
delete(m, "cache_control")
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -1089,6 +1309,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
body = injectClaudeCodePrompt(body, parsed.System)
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
@@ -1312,6 +1535,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 处理正常响应
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
||||
if err != nil {
|
||||
@@ -1324,6 +1548,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
if err != nil {
|
||||
@@ -1332,12 +1557,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ClientDisconnect: clientDisconnect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1692,8 +1918,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
||||
|
||||
// streamingResult 流式响应结果
|
||||
type streamingResult struct {
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||
@@ -1789,14 +2016,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
// 上游完成,返回结果
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
log.Printf("Context canceled during streaming, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
@@ -1807,38 +2047,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
line := ev.line
|
||||
if line == "event: error" {
|
||||
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
return nil, errors.New("have error in stream")
|
||||
}
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
var data string
|
||||
if sseDataRe.MatchString(line) {
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
|
||||
data = sseDataRe.ReplaceAllString(line, "")
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
}
|
||||
|
||||
// 转发行
|
||||
// 写入客户端(统一处理 data 行和非 data 行)
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
clientDisconnected = true
|
||||
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
// 无论客户端是否断开,都解析 usage(仅对 data 行)
|
||||
if data != "" {
|
||||
if firstTokenMs == nil && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// 非 data 行直接转发
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -1846,6 +2088,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
// 客户端已断开,上游也超时了,返回已收集的 usage
|
||||
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
@@ -1999,6 +2246,7 @@ type RecordUsageInput struct {
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -2009,25 +2257,40 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 计算费用
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
}
|
||||
|
||||
// 获取费率倍数
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
}
|
||||
|
||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
if err != nil {
|
||||
log.Printf("Calculate cost failed: %v", err)
|
||||
// 使用默认费用继续
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
var cost *CostBreakdown
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.ImageCount > 0 {
|
||||
// 图片生成计费
|
||||
var groupConfig *ImagePriceConfig
|
||||
if apiKey.Group != nil {
|
||||
groupConfig = &ImagePriceConfig{
|
||||
Price1K: apiKey.Group.ImagePrice1K,
|
||||
Price2K: apiKey.Group.ImagePrice2K,
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
if err != nil {
|
||||
log.Printf("Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
@@ -2039,6 +2302,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
// 创建使用日志
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
var imageSize *string
|
||||
if result.ImageSize != "" {
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
@@ -2060,9 +2327,16 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加分组和订阅关联
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
|
||||
@@ -109,7 +109,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
@@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
}
|
||||
if usable {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -217,7 +217,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
|
||||
@@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, error
|
||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
@@ -187,14 +190,14 @@ type mockGatewayCacheForGemini struct {
|
||||
sessionBindings map[string]int64
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
if id, ok := m.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if m.sessionBindings == nil {
|
||||
m.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
@@ -202,7 +205,7 @@ func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, ses
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
// OAuth client selection:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
|
||||
// - ai_studio: requires a user-provided OAuth client.
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: always use built-in Gemini CLI OAuth client (public)
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfg.ClientID = ""
|
||||
oauthCfg.ClientSecret = ""
|
||||
}
|
||||
@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
case "google_one":
|
||||
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
|
||||
|
||||
// Google One accounts use cloudaicompanion API, which requires a project_id.
|
||||
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
|
||||
if projectID == "" {
|
||||
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
var err error
|
||||
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
|
||||
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
|
||||
// Attempt to fetch Drive storage tier
|
||||
var storageInfo *geminicli.DriveStorageInfo
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "google_one uses custom client when configured and redirects to localhost",
|
||||
name: "google_one always forces built-in client even when custom client configured",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
@@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
|
||||
wantScope: geminicli.DefaultGoogleOneScopes,
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
|
||||
@@ -17,6 +17,15 @@ type Group struct {
|
||||
MonthlyLimitUSD *float64
|
||||
DefaultValidityDays int
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -47,3 +56,19 @@ func (g *Group) HasWeeklyLimit() bool {
|
||||
func (g *Group) HasMonthlyLimit() bool {
|
||||
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
|
||||
}
|
||||
|
||||
// GetImagePrice 根据 image_size 返回对应的图片生成价格
|
||||
// 如果分组未配置价格,返回 nil(调用方应使用默认值)
|
||||
func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
switch imageSize {
|
||||
case "1K":
|
||||
return g.ImagePrice1K
|
||||
case "2K":
|
||||
return g.ImagePrice2K
|
||||
case "4K":
|
||||
return g.ImagePrice4K
|
||||
default:
|
||||
// 未知尺寸默认按 2K 计费
|
||||
return g.ImagePrice2K
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user