mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
198 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
e6d59216d4 | ||
|
|
4e8615f276 | ||
|
|
91e4d95660 | ||
|
|
45456fa24c | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
6344fa2a86 | ||
|
|
7e288acc90 | ||
|
|
29b0e4a8a5 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
53ad1645cf | ||
|
|
ecea13757b | ||
|
|
af9c4a7dd0 | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
826090e099 | ||
|
|
7399de6ecc | ||
|
|
25cb5e7505 | ||
|
|
5c13ec3121 | ||
|
|
d8aff3a7e3 | ||
|
|
f44927b9f8 | ||
|
|
c0110cb5af | ||
|
|
1f8e1142a0 | ||
|
|
1e51de88d6 | ||
|
|
30995b5397 | ||
|
|
eb60f67054 | ||
|
|
78193ceec1 | ||
|
|
f0e08e7687 | ||
|
|
10b8259259 | ||
|
|
6826149a8f | ||
|
|
eb0b77bf4d | ||
|
|
9d81467937 | ||
|
|
fd8ccaf01a | ||
|
|
c9debc50b1 | ||
|
|
2b30e3b6d7 | ||
|
|
6e90ec6111 | ||
|
|
8dd38f4775 | ||
|
|
fbd73f248f | ||
|
|
3fcefe6c32 | ||
|
|
f740d2c291 | ||
|
|
bf6585a40f | ||
|
|
8c2dd7b3f0 | ||
|
|
4167c437a8 | ||
|
|
0ddaef3c9a | ||
|
|
2fc6aaf936 | ||
|
|
1c0519f1c7 | ||
|
|
6bbe7800be | ||
|
|
2694149489 | ||
|
|
a17ac50118 | ||
|
|
656a77d585 | ||
|
|
7455476c60 | ||
|
|
36cda57c81 | ||
|
|
9f1f203b84 | ||
|
|
b41a8ca93f | ||
|
|
e3cf0c0e10 | ||
|
|
de18bce9aa | ||
|
|
3cc407bc0e | ||
|
|
00a0a12138 | ||
|
|
b08767a4f9 | ||
|
|
ac6bde7a98 | ||
|
|
d2d41d68dd | ||
|
|
944b7f7617 | ||
|
|
53825eb073 | ||
|
|
1a7f49513f | ||
|
|
885a2ce7ef | ||
|
|
14ba80a0af | ||
|
|
5fa22fdf82 | ||
|
|
bcaae2eb91 | ||
|
|
767a41e263 | ||
|
|
252d6c5301 | ||
|
|
7a4e65ad4b | ||
|
|
a582aa89a9 | ||
|
|
acefa1da12 | ||
|
|
a88698f3fc | ||
|
|
ebc6755b33 | ||
|
|
c8eff34388 | ||
|
|
f19b03825b | ||
|
|
25178cdbe1 | ||
|
|
a461538d58 | ||
|
|
b43ee62947 | ||
|
|
ebe6f418f3 | ||
|
|
391e79f8ee | ||
|
|
c7fcb7a84b | ||
|
|
87f4ed591e | ||
|
|
440d2e28ed | ||
|
|
6cb8980404 | ||
|
|
fe752bbd35 | ||
|
|
c74d451fa2 | ||
|
|
12d743fb35 | ||
|
|
6acb9f7910 | ||
|
|
eb6f5c6927 | ||
|
|
7ccb4c8ea3 | ||
|
|
4ce986d47d | ||
|
|
91ef085d7d | ||
|
|
97aaa24733 | ||
|
|
faf6441633 | ||
|
|
00c151b463 | ||
|
|
106b20cdbf | ||
|
|
c069b3b1e8 | ||
|
|
a2ae9f1f27 | ||
|
|
4cd6d86426 | ||
|
|
fa72f1947a | ||
|
|
9ee7d3935d | ||
|
|
1071fe0ac7 | ||
|
|
0be003377f | ||
|
|
ca3f497b56 | ||
|
|
034b84b707 | ||
|
|
1624523c4e | ||
|
|
313afe14ce | ||
|
|
01180b316f | ||
|
|
ee7d061001 | ||
|
|
60c5949a74 | ||
|
|
2ebbd4c94d | ||
|
|
785115c62b | ||
|
|
e643fc382c | ||
|
|
34aad82ac3 | ||
|
|
0c29468f90 | ||
|
|
9301dae63e | ||
|
|
2475d4a205 | ||
|
|
be75fc3474 | ||
|
|
785e049af3 | ||
|
|
be4e49e6d7 | ||
|
|
1307d604e7 | ||
|
|
45d57018eb | ||
|
|
03bf348530 | ||
|
|
cab60ef735 | ||
|
|
a3791104f9 | ||
|
|
2b3e40bb2a | ||
|
|
0c1dcad429 | ||
|
|
101ef0cf62 | ||
|
|
0debe0a80c | ||
|
|
d22e62ac8a | ||
|
|
1ee17383f8 | ||
|
|
b59c79c458 | ||
|
|
c28f691f32 |
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -42,6 +42,6 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.11
|
||||
version: v2.9
|
||||
args: --timeout=30m
|
||||
working-directory: backend
|
||||
20
Dockerfile
20
Dockerfile
@@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
38
README.md
38
README.md
@@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||
- **Rate Limiting** - Configurable request and token rate limits
|
||||
- **Admin Dashboard** - Web interface for monitoring and management
|
||||
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||
|
||||
## Ecosystem
|
||||
|
||||
Community projects that extend or integrate with Sub2API:
|
||||
|
||||
| Project | Description | Features |
|
||||
|---------|-------------|----------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||
|
||||
## Tech Stack
|
||||
|
||||
@@ -150,14 +160,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# Start services
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# View logs
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**What the script does:**
|
||||
- Downloads `docker-compose.local.yml` and `.env.example`
|
||||
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
|
||||
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
||||
- Creates `.env` file with auto-generated secrets
|
||||
- Creates data directories (uses local directories for easy backup/migration)
|
||||
@@ -522,6 +532,28 @@ sub2api/
|
||||
└── install.sh # One-click installation script
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
> **Please read carefully before using this project:**
|
||||
>
|
||||
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
|
||||
>
|
||||
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
38
README_CN.md
38
README_CN.md
@@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
- **并发控制** - 用户级和账号级并发限制
|
||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||
- **管理后台** - Web 界面进行监控和管理
|
||||
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||
|
||||
## 生态项目
|
||||
|
||||
围绕 Sub2API 的社区扩展与集成项目:
|
||||
|
||||
| 项目 | 说明 | 功能 |
|
||||
|------|------|------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||
|
||||
## 技术栈
|
||||
|
||||
@@ -154,14 +164,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# 启动服务
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# 查看日志
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**脚本功能:**
|
||||
- 下载 `docker-compose.local.yml` 和 `.env.example`
|
||||
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example`
|
||||
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
||||
- 创建 `.env` 文件并填充自动生成的密钥
|
||||
- 创建数据目录(使用本地目录,便于备份和迁移)
|
||||
@@ -588,6 +598,28 @@ sub2api/
|
||||
└── install.sh # 一键安装脚本
|
||||
```
|
||||
|
||||
## 免责声明
|
||||
|
||||
> **使用本项目前请仔细阅读:**
|
||||
>
|
||||
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
|
||||
>
|
||||
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// Server layer ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// Privacy client factory for OpenAI training opt-out
|
||||
providePrivacyClientFactory,
|
||||
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -87,6 +94,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -223,6 +231,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
@@ -104,7 +105,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
@@ -144,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -162,9 +168,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -199,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -226,11 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -245,6 +251,10 @@ type Application struct {
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -279,6 +289,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -414,6 +425,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -7,7 +7,7 @@ require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -66,7 +66,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||
github.com/aws/smithy-go v1.24.1 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
|
||||
@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||
|
||||
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
||||
|
||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||
type DashboardAggregationRetentionConfig struct {
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
@@ -1402,7 +1404,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||
}
|
||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||
}
|
||||
|
||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||
}
|
||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -84,10 +85,12 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
// Gemini 2.5 白名单
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
// Gemini 3 白名单
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
@@ -111,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
}
|
||||
|
||||
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||
var DefaultBedrockModelMapping = map[string]string{
|
||||
// Claude Opus
|
||||
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||
// Claude Sonnet
|
||||
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
// Claude Haiku
|
||||
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]string{
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
}
|
||||
}
|
||||
|
||||
enrichCredentialsFromIDToken(&item)
|
||||
|
||||
accountInput := &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||
// Existing credential values are never overwritten — only missing fields are filled.
|
||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||
if item.Credentials == nil {
|
||||
return
|
||||
}
|
||||
// Only enrich OpenAI/Sora OAuth accounts
|
||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, _ := item.Credentials["id_token"].(string)
|
||||
if strings.TrimSpace(idToken) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeIDToken skips expiry validation — safe for imported data
|
||||
claims, err := openai.DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := claims.GetUserInfo()
|
||||
if userInfo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fill missing fields only (never overwrite existing values)
|
||||
setIfMissing := func(key, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
if existing, _ := item.Credentials[key].(string); existing == "" {
|
||||
item.Credentials[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
setIfMissing("email", userInfo.Email)
|
||||
setIfMissing("plan_type", userInfo.PlanType)
|
||||
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
|
||||
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
|
||||
setIfMissing("organization_id", userInfo.OrganizationID)
|
||||
}
|
||||
|
||||
func normalizeProxyStatus(status string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||
switch normalized {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -95,7 +97,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -114,7 +116,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -626,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
// TestAccountRequest represents the request body for testing an account
|
||||
type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
@@ -656,10 +659,46 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService != nil {
|
||||
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverState handles unified recovery of recoverable account runtime state.
|
||||
// POST /api/v1/admin/accounts/:id/recover-state
|
||||
func (h *AccountHandler) RecoverState(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
}); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
@@ -715,52 +754,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
||||
// refreshSingleAccount refreshes credentials for a single OAuth account.
|
||||
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
|
||||
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
|
||||
}
|
||||
|
||||
var newCredentials map[string]any
|
||||
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformGemini {
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
|
||||
}
|
||||
|
||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -770,10 +788,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformAntigravity {
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -792,37 +809,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
return updatedAccount, "missing_project_id_temporary", nil
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
|
||||
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
@@ -844,20 +851,54 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
if warning == "missing_project_id_temporary" {
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
@@ -913,14 +954,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// BatchClearError handles batch clearing account errors
|
||||
// POST /api/v1/admin/accounts/batch-clear-error
|
||||
func (h *AccountHandler) BatchClearError(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, id := range req.AccountIDs {
|
||||
accountID := id // 闭包捕获
|
||||
g.Go(func() error {
|
||||
account, err := h.adminService.ClearAccountError(gctx, accountID)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": accountID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
successCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchRefresh handles batch refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/batch-refresh
|
||||
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 建立已获取账号的 ID 集合,检测缺失的 ID
|
||||
foundIDs := make(map[int64]bool, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc != nil {
|
||||
foundIDs[acc.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
var warnings []gin.H
|
||||
|
||||
// 将不存在的账号 ID 标记为失败
|
||||
for _, id := range req.AccountIDs {
|
||||
if !foundIDs[id] {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": id,
|
||||
"error": "account not found",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, account := range accounts {
|
||||
acc := account // 闭包捕获
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
g.Go(func() error {
|
||||
_, warning, err := h.refreshSingleAccount(gctx, acc)
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
successCount++
|
||||
if warning != "" {
|
||||
warnings = append(warnings, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"warning": warning,
|
||||
})
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
// POST /api/v1/admin/accounts/batch
|
||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
@@ -1516,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
@@ -175,6 +175,18 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
@@ -429,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
204
backend/internal/handler/admin/backup_handler.go
Normal file
204
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupHandler struct {
|
||||
backupService *service.BackupService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||
return &BackupHandler{
|
||||
backupService: backupService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"restored": true})
|
||||
}
|
||||
@@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
@@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"groups": stats,
|
||||
@@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -461,9 +466,60 @@ type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func parseRankingLimit(raw string) int {
|
||||
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || limit <= 0 {
|
||||
return 12
|
||||
}
|
||||
if limit > 50 {
|
||||
return 50
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||
// GET /api/v1/admin/dashboard/users-ranking
|
||||
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
Start string `json:"start"`
|
||||
End string `json:"end"`
|
||||
Limit int `json:"limit"`
|
||||
}{
|
||||
Start: startTime.UTC().Format(time.RFC3339),
|
||||
End: endTime.UTC().Format(time.RFC3339),
|
||||
Limit: limit,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user spending ranking")
|
||||
return
|
||||
}
|
||||
|
||||
payload := gin.H{
|
||||
"ranking": ranking.Ranking,
|
||||
"total_actual_cost": ranking.TotalActualCost,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
}
|
||||
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
|
||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCacheProbe struct {
|
||||
service.UsageLogRepository
|
||||
trendCalls atomic.Int32
|
||||
usersTrendCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
r.trendCalls.Add(1)
|
||||
return []usagestats.TrendDataPoint{{
|
||||
Date: "2026-03-11",
|
||||
Requests: 1,
|
||||
TotalTokens: 2,
|
||||
Cost: 3,
|
||||
ActualCost: 4,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
limit int,
|
||||
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
r.usersTrendCalls.Add(1)
|
||||
return []usagestats.UserUsageTrendPoint{{
|
||||
Date: "2026-03-11",
|
||||
UserID: 1,
|
||||
Email: "cache@test.dev",
|
||||
Requests: 2,
|
||||
Tokens: 20,
|
||||
Cost: 2,
|
||||
ActualCost: 1,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func resetDashboardReadCachesForTest() {
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||
}
|
||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
rankingLimit int
|
||||
ranking []usagestats.UserSpendingRankingItem
|
||||
rankingTotal float64
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
limit int,
|
||||
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
s.rankingLimit = limit
|
||||
return &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: s.ranking,
|
||||
TotalActualCost: s.rankingTotal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
repo := &dashboardUsageRepoCapture{
|
||||
ranking: []usagestats.UserSpendingRankingItem{
|
||||
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||
},
|
||||
rankingTotal: 88.8,
|
||||
}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 50, repo.rankingLimit)
|
||||
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
}
|
||||
|
||||
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
var (
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
)
|
||||
|
||||
type dashboardTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Model string `json:"model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardModelGroupCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardEntityTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
func cacheStatusValue(hit bool) string {
|
||||
if hit {
|
||||
return "hit"
|
||||
}
|
||||
return "miss"
|
||||
}
|
||||
|
||||
func mustMarshalDashboardCacheKey(value any) string {
|
||||
raw, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func snapshotPayloadAs[T any](payload any) (T, error) {
|
||||
typed, ok := payload.(T)
|
||||
if !ok {
|
||||
var zero T
|
||||
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
|
||||
}
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUsageTrendCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getModelStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getGroupStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.GroupStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||
return h.buildSnapshotV2Response(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters,
|
||||
includeStats,
|
||||
includeTrend,
|
||||
includeModels,
|
||||
includeGroups,
|
||||
includeUsersTrend,
|
||||
usersTrendLimit,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
response.Error(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
response.Success(c, cached.Payload)
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
filters *dashboardSnapshotV2Filters,
|
||||
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||
usersTrendLimit int,
|
||||
) (*dashboardSnapshotV2Response, error) {
|
||||
resp := &dashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
StartDate: startTime.Format("2006-01-02"),
|
||||
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeStats {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get dashboard statistics")
|
||||
}
|
||||
resp.Stats = &dashboardSnapshotV2Stats{
|
||||
DashboardStats: *stats,
|
||||
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeTrend {
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
||||
c.Request.Context(),
|
||||
trend, _, err := h.getUsageTrendCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
@@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get usage trend")
|
||||
}
|
||||
resp.Trend = trend
|
||||
}
|
||||
|
||||
if includeModels {
|
||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
models, _, err := h.getModelStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get model statistics")
|
||||
}
|
||||
resp.Models = models
|
||||
}
|
||||
|
||||
if includeGroups {
|
||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
groups, _, err := h.getGroupStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get group statistics")
|
||||
}
|
||||
resp.Groups = groups
|
||||
}
|
||||
|
||||
if includeUsersTrend {
|
||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
usersTrendLimit,
|
||||
)
|
||||
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get user usage trend")
|
||||
}
|
||||
resp.UsersTrend = usersTrend
|
||||
}
|
||||
|
||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||
|
||||
@@ -335,6 +335,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if entries == nil {
|
||||
entries = []service.UserGroupRateEntry{}
|
||||
}
|
||||
response.Success(c, entries)
|
||||
}
|
||||
|
||||
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||
type BatchSetGroupRateMultipliersRequest struct {
|
||||
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchSetGroupRateMultipliersRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
Extra: nil,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
|
||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent",
|
||||
"concurrency_queue_depth",
|
||||
"group_available_accounts",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_rate_limited_count",
|
||||
"account_error_count",
|
||||
"account_error_ratio",
|
||||
"overload_account_count",
|
||||
}
|
||||
|
||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
||||
"error_rate",
|
||||
"upstream_error_rate",
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent":
|
||||
"memory_usage_percent",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_error_ratio":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
Notes string `json:"notes"`
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Code = strings.TrimSpace(req.Code)
|
||||
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||
if req.Type == "" {
|
||||
req.Type = "balance"
|
||||
}
|
||||
|
||||
if req.Type == "subscription" {
|
||||
if req.GroupID == nil {
|
||||
response.BadRequest(c, "group_id is required for subscription type")
|
||||
return
|
||||
}
|
||||
if req.ValidityDays <= 0 {
|
||||
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
}
|
||||
|
||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
|
||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||
// parameter-validation layer that runs before any service call.
|
||||
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: newStubAdminService(),
|
||||
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||
}
|
||||
}
|
||||
|
||||
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||
// status code. For cases that pass validation and proceed into the service layer,
|
||||
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
jsonBytes, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||
code = 0
|
||||
}
|
||||
}()
|
||||
handler.CreateAndRedeem(c)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-default",
|
||||
"value": 10.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"omitting type should default to balance and pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-no-group",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"validity_days": 30,
|
||||
// group_id 缺失
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
validityDays int
|
||||
}{
|
||||
{"zero", 0},
|
||||
{"negative", -1},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-bad-days-" + tc.name,
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": tc.validityDays,
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-valid",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": 31,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"valid subscription params should pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-no-extras",
|
||||
"type": "balance",
|
||||
"value": 50.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"balance type should not require group_id or validity_days")
|
||||
}
|
||||
@@ -25,6 +25,7 @@ type createScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
@@ -32,6 +33,7 @@ type updateScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
@@ -68,6 +70,9 @@ func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
plan.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
@@ -109,6 +114,9 @@ func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
existing.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
|
||||
@@ -125,6 +125,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -199,6 +200,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -473,6 +477,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -571,6 +576,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -725,6 +731,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
@@ -1348,6 +1357,118 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
// GET /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: settings.Enabled,
|
||||
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRectifierSettingsRequest 更新整流器配置请求
|
||||
type UpdateRectifierSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// UpdateRectifierSettings 更新请求整流器配置
|
||||
// PUT /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
var req UpdateRectifierSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.RectifierSettings{
|
||||
Enabled: req.Enabled,
|
||||
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||
// GET /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
|
||||
for i, r := range settings.Rules {
|
||||
rules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: rules})
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
|
||||
type UpdateBetaPolicySettingsRequest struct {
|
||||
Rules []dto.BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettings 更新 Beta 策略配置
|
||||
// PUT /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
|
||||
var req UpdateBetaPolicySettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]service.BetaPolicyRule, len(req.Rules))
|
||||
for i, r := range req.Rules {
|
||||
rules[i] = service.BetaPolicyRule(r)
|
||||
}
|
||||
|
||||
settings := &service.BetaPolicySettings{Rules: rules}
|
||||
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch to return updated settings
|
||||
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
|
||||
for i, r := range updated.Rules {
|
||||
outRules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type snapshotCacheEntry struct {
|
||||
@@ -19,6 +21,12 @@ type snapshotCache struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
items map[string]snapshotCacheEntry
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
type snapshotCacheLoadResult struct {
|
||||
Entry snapshotCacheEntry
|
||||
Hit bool
|
||||
}
|
||||
|
||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||
if load == nil {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return entry, true, nil
|
||||
}
|
||||
if c == nil || key == "" {
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
return c.Set(key, payload), false, nil
|
||||
}
|
||||
|
||||
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||
}
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
result, ok := value.(snapshotCacheLoadResult)
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
return result.Entry, result.Hit, nil
|
||||
}
|
||||
|
||||
func buildETagFromAny(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||
require.Empty(t, etag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
|
||||
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"hello": "world"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, hit)
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
|
||||
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"unexpected": "value"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, hit)
|
||||
require.Equal(t, entry.ETag, entry2.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
start := make(chan struct{})
|
||||
const callers = 8
|
||||
errCh := make(chan error, callers)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
for range callers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||
loads.Add(1)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return "value", nil
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -216,6 +216,38 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||
type ResetSubscriptionQuotaRequest struct {
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
Monthly bool `json:"monthly"`
|
||||
}
|
||||
|
||||
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
var req ResetSubscriptionQuotaRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||
return
|
||||
}
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
|
||||
if tokenErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
|
||||
return
|
||||
}
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", "invitation_required")
|
||||
fragment.Set("pending_oauth_token", pendingToken)
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
return
|
||||
}
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
}
|
||||
|
||||
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||
// the invitation code and creating the user account.
|
||||
// POST /api/v1/auth/oauth/linuxdo/complete-registration
|
||||
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
var req completeLinuxDoOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
|
||||
@@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
out := &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -98,6 +98,19 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) {
|
||||
t := k.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
out.Reset5hAt = &t
|
||||
}
|
||||
if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) {
|
||||
t := k.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
out.Reset1dAt = &t
|
||||
}
|
||||
if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) {
|
||||
t := k.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
out.Reset7dAt = &t
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
@@ -125,9 +138,9 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
@@ -251,15 +264,48 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
if out.QuotaLimit != nil {
|
||||
used := a.GetQuotaUsed()
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -475,6 +521,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
|
||||
@@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -81,6 +81,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -111,6 +114,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -161,6 +165,26 @@ type StreamTimeoutSettings struct {
|
||||
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置 DTO
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// BetaPolicyRule Beta 策略规则 DTO
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// BetaPolicySettings Beta 策略配置 DTO
|
||||
type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -57,6 +57,9 @@ type APIKey struct {
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
Reset5hAt *time.Time `json:"reset_5h_at,omitempty"`
|
||||
Reset1dAt *time.Time `json:"reset_1d_at,omitempty"`
|
||||
Reset7dAt *time.Time `json:"reset_7d_at,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -193,8 +196,22 @@ type Account struct {
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"`
|
||||
QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"`
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -315,6 +332,8 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
maxSameAccountRetries = 3
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
|
||||
@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
for i := 1; i <= maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, i, fs.SameAccountRetryCount[100])
|
||||
}
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
// 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
}
|
||||
// 再次触发时才会执行 TempUnschedule + 切换
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
}
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
@@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -652,6 +654,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
// Beta policy block: return 400 immediately, no failover
|
||||
var betaBlockedErr *service.BetaBlockedError
|
||||
if errors.As(err, &betaBlockedErr) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||
return
|
||||
}
|
||||
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||
@@ -729,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -972,33 +983,45 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.EffectiveUsage5h()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit5h-used),
|
||||
"window_start": rateLimitData.Window5hStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) {
|
||||
entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.EffectiveUsage1d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit1d-used),
|
||||
"window_start": rateLimitData.Window1dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) {
|
||||
entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.EffectiveUsage7d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit7d-used),
|
||||
"window_start": rateLimitData.Window7dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) {
|
||||
entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if len(rateLimits) > 0 {
|
||||
resp["rate_limits"] = rateLimits
|
||||
|
||||
@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
@@ -138,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // usageBillingRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
@@ -155,6 +157,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // sessionLimitCache
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
|
||||
@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
|
||||
284
backend/internal/handler/openai_chat_completions.go
Normal file
284
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||
// POST /v1/chat/completions
|
||||
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_chat_completions.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -212,6 +213,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -259,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
@@ -288,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -330,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -538,9 +562,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
|
||||
// 而非 OpenAI 的 session_id/conversation_id headers。
|
||||
// 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。
|
||||
if sessionHash == "" || promptCacheKey == "" {
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = service.GenerateSessionUUID(seed)
|
||||
}
|
||||
if sessionHash == "" {
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -602,6 +642,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -614,14 +655,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
@@ -641,6 +677,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -674,17 +729,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
@@ -1173,14 +1230,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
@@ -1456,6 +1514,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
|
||||
if sessionHash != "" || account == nil || !account.IsPoolMode() {
|
||||
return sessionHash
|
||||
}
|
||||
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
|
||||
return "openai-pool-retry-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
const (
|
||||
opsErrorLogTimeout = 5 * time.Second
|
||||
opsErrorLogDrainTimeout = 10 * time.Second
|
||||
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||
|
||||
opsErrorLogMinWorkerCount = 4
|
||||
opsErrorLogMaxWorkerCount = 32
|
||||
@@ -38,6 +39,7 @@ const (
|
||||
opsErrorLogQueueSizePerWorker = 128
|
||||
opsErrorLogMinQueueSize = 256
|
||||
opsErrorLogMaxQueueSize = 8192
|
||||
opsErrorLogBatchSize = 32
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
@@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() {
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go func() {
|
||||
defer opsErrorLogWorkersWg.Done()
|
||||
for job := range opsErrorLogQueue {
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
for {
|
||||
job, ok := <-opsErrorLogQueue
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||
batch = append(batch, job)
|
||||
|
||||
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||
batchLoop:
|
||||
for len(batch) < opsErrorLogBatchSize {
|
||||
select {
|
||||
case nextJob, ok := <-opsErrorLogQueue:
|
||||
if !ok {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
return
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch = append(batch, nextJob)
|
||||
case <-timer.C:
|
||||
break batchLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||
var processed int64
|
||||
for _, job := range batch {
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
}
|
||||
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||
processed++
|
||||
}
|
||||
if processed == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for opsSvc, entries := range grouped {
|
||||
if opsSvc == nil || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||
cancel()
|
||||
}
|
||||
opsErrorLogProcessed.Add(processed)
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2206,8 +2206,8 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
|
||||
@@ -343,6 +343,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -431,6 +434,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
@@ -445,6 +449,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
backupHandler *admin.BackupHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
Backup: backupHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewBackupHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
|
||||
@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
|
||||
@@ -19,6 +19,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// ForbiddenError 表示上游返回 403 Forbidden
|
||||
type ForbiddenError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *ForbiddenError) Error() string {
|
||||
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -514,7 +524,20 @@ type ModelQuotaInfo struct {
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||
}
|
||||
|
||||
// DeprecatedModelInfo 废弃模型转发信息
|
||||
type DeprecatedModelInfo struct {
|
||||
NewModelID string `json:"newModelId"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
@@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct {
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, nil, &ForbiddenError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBodyBytes),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
|
||||
"<|user|>",
|
||||
"<|endoftext|>",
|
||||
"<|end_of_turn|>",
|
||||
"[DONE]",
|
||||
"\n\nHuman:",
|
||||
}
|
||||
|
||||
@@ -49,8 +49,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
||||
var defaultUserAgentVersion = "1.20.4"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Empty(t, items[2].ID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
@@ -631,7 +632,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.Contains(t, resp.Include, "reasoning.encrypted_content")
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
@@ -648,7 +650,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
@@ -663,8 +666,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
// Default effort applies (high → xhigh) even when thinking is disabled.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
@@ -676,7 +680,93 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
// Default effort applies (high → xhigh) when no thinking/output_config is set.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// output_config.effort override tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
|
||||
// Default is xhigh, but output_config.effort="low" overrides. low→low after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "low"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "low", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
|
||||
// No thinking field, but output_config.effort="medium" → creates reasoning.
|
||||
// medium→high after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "medium"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
|
||||
// output_config.effort="high" → mapped to "xhigh".
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "high"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
|
||||
// No output_config → default xhigh regardless of thinking.type.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
|
||||
// output_config present but effort empty (e.g. only format set) → default xhigh.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -733,3 +823,188 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", fn["name"])
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Image content block conversion tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_UserImageBlock(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"text","text":"What is in this image?"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "What is in this image?", parts[0].Text)
|
||||
assert.Equal(t, "input_image", parts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ImageOnlyUserMessage(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"/9j/4AAQ"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/jpeg;base64,/9j/4AAQ", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolResultWithImage(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Read the screenshot"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_1","name":"Read","input":{"file_path":"/tmp/screen.png"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"toolu_1","content":[
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output + user(image) = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
// function_call_output should have text-only output (no image).
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "fc_toolu_1", items[2].CallID)
|
||||
assert.Equal(t, "(empty)", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
assert.Equal(t, "user", items[3].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[3].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolResultMixed(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Describe the file"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_2","name":"Read","input":{"file_path":"/tmp/photo.png"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"toolu_2","content":[
|
||||
{"type":"text","text":"File metadata: 800x600 PNG"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output + user(image) = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
// function_call_output should have text-only output.
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "File metadata: 800x600 PNG", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
assert.Equal(t, "user", items[3].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[3].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/png;base64,AAAA", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_TextOnlyToolResultBackwardCompat(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Check weather"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"call_1","content":[
|
||||
{"type":"text","text":"Sunny, 72°F"}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output = 3
|
||||
require.Len(t, items, 3)
|
||||
|
||||
// Text-only tool_result should produce a plain string.
|
||||
assert.Equal(t, "Sunny, 72°F", items[2].Output)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"image","source":{"type":"base64","media_type":"","data":"iVBOR"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
// Should default to image/png when media_type is empty.
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
@@ -45,18 +45,16 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
out.Tools = convertAnthropicToolsToResponses(req.Tools)
|
||||
}
|
||||
|
||||
// Convert thinking → reasoning.
|
||||
// generate_summary="auto" causes the upstream to emit reasoning_summary_text
|
||||
// streaming events; the include array only needs reasoning.encrypted_content
|
||||
// (already set above) for content continuity.
|
||||
if req.Thinking != nil {
|
||||
switch req.Thinking.Type {
|
||||
case "enabled":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "high", Summary: "auto"}
|
||||
case "adaptive":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "medium", Summary: "auto"}
|
||||
}
|
||||
// "disabled" or unknown → omit reasoning
|
||||
// Determine reasoning effort: only output_config.effort controls the
|
||||
// level; thinking.type is ignored. Default is xhigh when unset.
|
||||
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
|
||||
effort := "high" // default → maps to xhigh
|
||||
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
|
||||
effort = req.OutputConfig.Effort
|
||||
}
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: mapAnthropicEffortToResponses(effort),
|
||||
Summary: "auto",
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
@@ -169,7 +167,7 @@ func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, err
|
||||
|
||||
// anthropicUserToResponses handles an Anthropic user message. Content can be a
|
||||
// plain string or an array of blocks. tool_result blocks are extracted into
|
||||
// function_call_output items.
|
||||
// function_call_output items. Image blocks are converted to input_image parts.
|
||||
func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
@@ -184,28 +182,46 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error)
|
||||
}
|
||||
|
||||
var out []ResponsesInputItem
|
||||
var toolResultImageParts []ResponsesContentPart
|
||||
|
||||
// Extract tool_result blocks → function_call_output items.
|
||||
// Images inside tool_results are extracted separately because the
|
||||
// Responses API function_call_output.output only accepts strings.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_result" {
|
||||
continue
|
||||
}
|
||||
text := extractAnthropicToolResultText(b)
|
||||
if text == "" {
|
||||
// OpenAI Responses API requires "output" field; use placeholder for empty results.
|
||||
text = "(empty)"
|
||||
}
|
||||
outputText, imageParts := convertToolResultOutput(b)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Type: "function_call_output",
|
||||
CallID: toResponsesCallID(b.ToolUseID),
|
||||
Output: text,
|
||||
Output: outputText,
|
||||
})
|
||||
toolResultImageParts = append(toolResultImageParts, imageParts...)
|
||||
}
|
||||
|
||||
// Remaining text blocks → user message.
|
||||
text := extractAnthropicTextFromBlocks(blocks)
|
||||
if text != "" {
|
||||
content, _ := json.Marshal(text)
|
||||
// Remaining text + image blocks → user message with content parts.
|
||||
// Also include images extracted from tool_results so the model can see them.
|
||||
var parts []ResponsesContentPart
|
||||
for _, b := range blocks {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
if b.Text != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(b.Source); uri != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, toolResultImageParts...)
|
||||
|
||||
if len(parts) > 0 {
|
||||
content, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ResponsesInputItem{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
@@ -261,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -290,26 +305,64 @@ func fromResponsesCallID(id string) string {
|
||||
return id
|
||||
}
|
||||
|
||||
// extractAnthropicToolResultText gets the text content from a tool_result block.
|
||||
func extractAnthropicToolResultText(b AnthropicContentBlock) string {
|
||||
if len(b.Content) == 0 {
|
||||
// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string.
|
||||
// Returns "" if the source is nil or has no data.
|
||||
func anthropicImageToDataURI(src *AnthropicImageSource) string {
|
||||
if src == nil || src.Data == "" {
|
||||
return ""
|
||||
}
|
||||
mediaType := src.MediaType
|
||||
if mediaType == "" {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
return "data:" + mediaType + ";base64," + src.Data
|
||||
}
|
||||
|
||||
// convertToolResultOutput extracts text and image content from a tool_result
|
||||
// block. Returns the text as a string for the function_call_output Output
|
||||
// field, plus any image parts that must be sent in a separate user message
|
||||
// (the Responses API output field only accepts strings).
|
||||
func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) {
|
||||
if len(b.Content) == 0 {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Try plain string content.
|
||||
var s string
|
||||
if err := json.Unmarshal(b.Content, &s); err == nil {
|
||||
return s
|
||||
if s == "" {
|
||||
s = "(empty)"
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Array of content blocks — may contain text and/or images.
|
||||
var inner []AnthropicContentBlock
|
||||
if err := json.Unmarshal(b.Content, &inner); err == nil {
|
||||
var parts []string
|
||||
for _, ib := range inner {
|
||||
if ib.Type == "text" && ib.Text != "" {
|
||||
parts = append(parts, ib.Text)
|
||||
if err := json.Unmarshal(b.Content, &inner); err != nil {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Separate text (for function_call_output) from images (for user message).
|
||||
var textParts []string
|
||||
var imageParts []ResponsesContentPart
|
||||
for _, ib := range inner {
|
||||
switch ib.Type {
|
||||
case "text":
|
||||
if ib.Text != "" {
|
||||
textParts = append(textParts, ib.Text)
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(ib.Source); uri != "" {
|
||||
imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
return ""
|
||||
|
||||
text := strings.Join(textParts, "\n\n")
|
||||
if text == "" {
|
||||
text = "(empty)"
|
||||
}
|
||||
return text, imageParts
|
||||
}
|
||||
|
||||
// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/
|
||||
@@ -324,6 +377,23 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
|
||||
// OpenAI Responses API effort levels.
|
||||
//
|
||||
// low → low
|
||||
// medium → high
|
||||
// high → xhigh
|
||||
func mapAnthropicEffortToResponses(effort string) string {
|
||||
switch effort {
|
||||
case "medium":
|
||||
return "high"
|
||||
case "high":
|
||||
return "xhigh"
|
||||
default:
|
||||
return effort // "low" and any unknown values pass through unchanged
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
|
||||
// Responses API tools. Server-side tools like web_search are mapped to their
|
||||
// OpenAI equivalents; regular tools become function tools.
|
||||
|
||||
810
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
810
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
@@ -0,0 +1,810 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ChatCompletionsToResponses tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4o", resp.Model)
|
||||
assert.True(t, resp.Stream) // always forced true
|
||||
assert.False(t, *resp.Store)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "system", items[0].Role)
|
||||
assert.Equal(t, "user", items[1].Role)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "ping",
|
||||
Arguments: `{"host":"example.com"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_1",
|
||||
Content: json.RawMessage(`"pong"`),
|
||||
},
|
||||
},
|
||||
Tools: []ChatTool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: &ChatFunction{
|
||||
Name: "ping",
|
||||
Description: "Ping a host",
|
||||
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output = 3
|
||||
// (assistant message with empty content + tool_calls → only function_call items emitted)
|
||||
require.Len(t, items, 3)
|
||||
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Empty(t, items[1].ID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "call_1", items[2].CallID)
|
||||
assert.Equal(t, "pong", items[2].Output)
|
||||
|
||||
// Check tools
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "ping", resp.Tools[0].Name)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
|
||||
t.Run("max_tokens", func(t *testing.T) {
|
||||
maxTokens := 100
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
MaxTokens: &maxTokens,
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.MaxOutputTokens)
|
||||
// Below minMaxOutputTokens (128), should be clamped
|
||||
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
|
||||
})
|
||||
|
||||
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
|
||||
maxTokens := 100
|
||||
maxCompletion := 500
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
MaxTokens: &maxTokens,
|
||||
MaxCompletionTokens: &maxCompletion,
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.MaxOutputTokens)
|
||||
assert.Equal(t, 500, *resp.MaxOutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
ReasoningEffort: "high",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
|
||||
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(content)},
|
||||
},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "Describe this", parts[0].Text)
|
||||
assert.Equal(t, "input_image", parts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
},
|
||||
Functions: []ChatFunction{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
},
|
||||
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||
|
||||
// tool_choice should be converted
|
||||
require.NotNil(t, resp.ToolChoice)
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
ServiceTier: "flex",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "flex", resp.ServiceTier)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Do something"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`"Let me call a function."`),
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_abc",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "do_thing",
|
||||
Arguments: `{}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + assistant message (with text) + function_call
|
||||
require.Len(t, items, 3)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Empty(t, items[2].ID)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Equal(t, "AB", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ResponsesToChatCompletions tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_123",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Hello, world!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
assert.Equal(t, "chat.completion", chat.Object)
|
||||
assert.Equal(t, "gpt-4o", chat.Model)
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "Hello, world!", content)
|
||||
|
||||
require.NotNil(t, chat.Usage)
|
||||
assert.Equal(t, 10, chat.Usage.PromptTokens)
|
||||
assert.Equal(t, 5, chat.Usage.CompletionTokens)
|
||||
assert.Equal(t, 15, chat.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_456",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_xyz",
|
||||
Name: "get_weather",
|
||||
Arguments: `{"city":"NYC"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)
|
||||
|
||||
msg := chat.Choices[0].Message
|
||||
require.Len(t, msg.ToolCalls, 1)
|
||||
assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
|
||||
assert.Equal(t, "function", msg.ToolCalls[0].Type)
|
||||
assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
|
||||
assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_789",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "reasoning",
|
||||
Summary: []ResponsesSummary{
|
||||
{Type: "summary_text", Text: "I thought about it."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "The answer is 42."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "The answer is 42.", content)
|
||||
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_inc",
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "partial..."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "length", chat.Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cache",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 10,
|
||||
TotalTokens: 110,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 80,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.NotNil(t, chat.Usage)
|
||||
require.NotNil(t, chat.Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_ws",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "web_search_call",
|
||||
Action: &WebSearchAction{Type: "search", Query: "test"},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "search results", content)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesEventToChatChunks tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
// response.created → role chunk
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{
|
||||
ID: "resp_stream",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
|
||||
assert.True(t, state.SentRole)
|
||||
|
||||
// response.output_text.delta → content chunk
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "Hello",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 1,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
CallID: "call_1",
|
||||
Name: "get_weather",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
|
||||
tc := chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
assert.Equal(t, "call_1", tc.ID)
|
||||
assert.Equal(t, "get_weather", tc.Function.Name)
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index)
|
||||
|
||||
// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 1, // matches the output_index from output_item.added above
|
||||
Delta: `{"city":`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
|
||||
assert.Equal(t, `{"city":`, tc.Function.Arguments)
|
||||
|
||||
// Add a second function call at output_index=2
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 2,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
CallID: "call_2",
|
||||
Name: "get_time",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")
|
||||
|
||||
// Argument delta for second tool call
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 2,
|
||||
Delta: `{"tz":"UTC"}`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")
|
||||
|
||||
// Argument delta for first tool call (interleaved)
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 1,
|
||||
Delta: `"Tokyo"}`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 50,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 70,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
// finish chunk + usage chunk
|
||||
require.Len(t, chunks, 2)
|
||||
|
||||
// First chunk: finish_reason
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
|
||||
// Second chunk: usage
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
|
||||
assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
|
||||
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SawToolCall = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.done",
|
||||
}, state)
|
||||
require.Len(t, chunks, 0)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "plan",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "answer",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
state.Usage = &ChatUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
}
|
||||
|
||||
chunks := FinalizeResponsesChatStream(state)
|
||||
require.Len(t, chunks, 2)
|
||||
|
||||
// Finish chunk
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
|
||||
// Usage chunk
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 100, chunks[1].Usage.PromptTokens)
|
||||
|
||||
// Idempotent: second call returns nil
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
|
||||
// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
|
||||
// must be a no-op (prevents double finish_reason being sent to the client).
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
// Simulate response.completed
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
require.NotEmpty(t, chunks) // finish + usage chunks
|
||||
|
||||
// Now FinalizeResponsesChatStream should return nil — already finalized.
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestChatChunkToSSE(t *testing.T) {
|
||||
chunk := ChatCompletionsChunk{
|
||||
ID: "chatcmpl-test",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1700000000,
|
||||
Model: "gpt-4o",
|
||||
Choices: []ChatChunkChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: ChatDelta{Role: "assistant"},
|
||||
FinishReason: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sse, err := ChatChunkToSSE(chunk)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, sse, "data: ")
|
||||
assert.Contains(t, sse, "chatcmpl-test")
|
||||
assert.Contains(t, sse, "assistant")
|
||||
assert.True(t, len(sse) > 10)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stream round-trip test
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
|
||||
// Simulate: client sends chat completions request, upstream returns Responses SSE events.
|
||||
// Verify that the streaming state machine produces correct chat completions chunks.
|
||||
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
var allChunks []ChatCompletionsChunk
|
||||
|
||||
// 1. response.created
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_rt"},
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
|
||||
// 2. text deltas
|
||||
for _, text := range []string{"Hello", ", ", "world", "!"} {
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: text,
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
}
|
||||
|
||||
// 3. response.completed
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 4,
|
||||
TotalTokens: 14,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
|
||||
// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
|
||||
require.Len(t, allChunks, 7)
|
||||
|
||||
// First chunk has role
|
||||
assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)
|
||||
|
||||
// Text chunks
|
||||
var fullText string
|
||||
for i := 1; i <= 4; i++ {
|
||||
require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
|
||||
fullText += *allChunks[i].Choices[0].Delta.Content
|
||||
}
|
||||
assert.Equal(t, "Hello, world!", fullText)
|
||||
|
||||
// Finish chunk
|
||||
require.NotNil(t, allChunks[5].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)
|
||||
|
||||
// Usage chunk
|
||||
require.NotNil(t, allChunks[6].Usage)
|
||||
assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
|
||||
assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)
|
||||
|
||||
// All chunks share the same ID
|
||||
for _, c := range allChunks {
|
||||
assert.Equal(t, "resp_rt", c.ID)
|
||||
}
|
||||
}
|
||||
385
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
385
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
// Responses API request. The upstream always streams, so Stream is forced to
|
||||
// true. store is always false and reasoning.encrypted_content is always
|
||||
// included so that the response translator has full context.
|
||||
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
|
||||
input, err := convertChatMessagesToResponsesInput(req.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &ResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: true, // upstream always streams
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
ServiceTier: req.ServiceTier,
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
|
||||
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
|
||||
maxTokens := 0
|
||||
if req.MaxTokens != nil {
|
||||
maxTokens = *req.MaxTokens
|
||||
}
|
||||
if req.MaxCompletionTokens != nil {
|
||||
maxTokens = *req.MaxCompletionTokens
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
v := maxTokens
|
||||
if v < minMaxOutputTokens {
|
||||
v = minMaxOutputTokens
|
||||
}
|
||||
out.MaxOutputTokens = &v
|
||||
}
|
||||
|
||||
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
|
||||
if req.ReasoningEffort != "" {
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: req.ReasoningEffort,
|
||||
Summary: "auto",
|
||||
}
|
||||
}
|
||||
|
||||
// tools[] and legacy functions[] → ResponsesTool[]
|
||||
if len(req.Tools) > 0 || len(req.Functions) > 0 {
|
||||
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
|
||||
}
|
||||
|
||||
// tool_choice: already compatible format — pass through directly.
|
||||
// Legacy function_call needs mapping.
|
||||
if len(req.ToolChoice) > 0 {
|
||||
out.ToolChoice = req.ToolChoice
|
||||
} else if len(req.FunctionCall) > 0 {
|
||||
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert function_call: %w", err)
|
||||
}
|
||||
out.ToolChoice = tc
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// convertChatMessagesToResponsesInput converts the Chat Completions messages
|
||||
// array into a Responses API input items array.
|
||||
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var out []ResponsesInputItem
|
||||
for _, m := range msgs {
|
||||
items, err := chatMessageToResponsesItems(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// chatMessageToResponsesItems converts a single ChatMessage into one or more
|
||||
// ResponsesInputItem values.
|
||||
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
switch m.Role {
|
||||
case "system":
|
||||
return chatSystemToResponses(m)
|
||||
case "user":
|
||||
return chatUserToResponses(m)
|
||||
case "assistant":
|
||||
return chatAssistantToResponses(m)
|
||||
case "tool":
|
||||
return chatToolToResponses(m)
|
||||
case "function":
|
||||
return chatFunctionToResponses(m)
|
||||
default:
|
||||
return chatUserToResponses(m)
|
||||
}
|
||||
}
|
||||
|
||||
// chatSystemToResponses converts a system message.
|
||||
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
text, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content, err := json.Marshal(text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
|
||||
}
|
||||
|
||||
// chatUserToResponses converts a user message, handling both plain strings and
|
||||
// multi-modal content arrays.
|
||||
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string first.
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var parts []ChatContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err != nil {
|
||||
return nil, fmt.Errorf("parse user content: %w", err)
|
||||
}
|
||||
|
||||
var responseParts []ResponsesContentPart
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_text",
|
||||
Text: p.Text,
|
||||
})
|
||||
}
|
||||
case "image_url":
|
||||
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_image",
|
||||
ImageURL: p.ImageURL.URL,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content, err := json.Marshal(responseParts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
// chatAssistantToResponses converts an assistant message. If there is both
|
||||
// text content and tool_calls, the text is emitted as an assistant message
|
||||
// first, then each tool_call becomes a function_call item. If the content is
|
||||
// empty/nil and there are tool_calls, only function_call items are emitted.
|
||||
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var items []ResponsesInputItem
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
s, err := parseAssistantContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
}
|
||||
|
||||
// Emit one function_call item per tool_call.
|
||||
for _, tc := range m.ToolCalls {
|
||||
args := tc.Function.Arguments
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
items = append(items, ResponsesInputItem{
|
||||
Type: "function_call",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// parseAssistantContent returns assistant content as plain text.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JSON string
|
||||
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||
//
|
||||
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var parts []map[string]any
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Keep compatibility with prior behavior: unsupported assistant content
|
||||
// formats are ignored instead of failing the whole request conversion.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
write := func(v string) error {
|
||||
_, err := b.WriteString(v)
|
||||
return err
|
||||
}
|
||||
for _, p := range parts {
|
||||
typ, _ := p["type"].(string)
|
||||
text, _ := p["text"].(string)
|
||||
thinking, _ := p["thinking"].(string)
|
||||
|
||||
switch typ {
|
||||
case "thinking", "reasoning":
|
||||
if thinking != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(thinking); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if text != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if text != "" {
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
output, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if output == "" {
|
||||
output = "(empty)"
|
||||
}
|
||||
return []ResponsesInputItem{{
|
||||
Type: "function_call_output",
|
||||
CallID: m.ToolCallID,
|
||||
Output: output,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// chatFunctionToResponses converts a legacy function result message
|
||||
// (role=function) into a function_call_output item. The Name field is used as
|
||||
// call_id since legacy function calls do not carry a separate call_id.
|
||||
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
output, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if output == "" {
|
||||
output = "(empty)"
|
||||
}
|
||||
return []ResponsesInputItem{{
|
||||
Type: "function_call_output",
|
||||
CallID: m.Name,
|
||||
Output: output,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// parseChatContent returns the string value of a ChatMessage Content field.
|
||||
// Content must be a JSON string. Returns "" if content is null or empty.
|
||||
func parseChatContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err != nil {
|
||||
return "", fmt.Errorf("parse content as string: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
|
||||
// function definitions to Responses API tool definitions.
|
||||
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
|
||||
var out []ResponsesTool
|
||||
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" || t.Function == nil {
|
||||
continue
|
||||
}
|
||||
rt := ResponsesTool{
|
||||
Type: "function",
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: t.Function.Parameters,
|
||||
Strict: t.Function.Strict,
|
||||
}
|
||||
out = append(out, rt)
|
||||
}
|
||||
|
||||
// Legacy functions[] are treated as function-type tools.
|
||||
for _, f := range functions {
|
||||
rt := ResponsesTool{
|
||||
Type: "function",
|
||||
Name: f.Name,
|
||||
Description: f.Description,
|
||||
Parameters: f.Parameters,
|
||||
Strict: f.Strict,
|
||||
}
|
||||
out = append(out, rt)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
|
||||
// Responses API tool_choice value.
|
||||
//
|
||||
// "auto" → "auto"
|
||||
// "none" → "none"
|
||||
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Object form: {"name":"X"}
|
||||
var obj struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": obj.Name},
|
||||
})
|
||||
}
|
||||
374
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
374
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesToChatCompletions converts a Responses API response into a Chat
|
||||
// Completions response. Text output items are concatenated into
|
||||
// choices[0].message.content; function_call items become tool_calls.
|
||||
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
|
||||
id := resp.ID
|
||||
if id == "" {
|
||||
id = generateChatCmplID()
|
||||
}
|
||||
|
||||
out := &ChatCompletionsResponse{
|
||||
ID: id,
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: model,
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
for _, part := range item.Content {
|
||||
if part.Type == "output_text" && part.Text != "" {
|
||||
contentText += part.Text
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
toolCalls = append(toolCalls, ChatToolCall{
|
||||
ID: item.CallID,
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: item.Name,
|
||||
Arguments: item.Arguments,
|
||||
},
|
||||
})
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
reasoningText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
// silently consumed — results already incorporated into text output
|
||||
}
|
||||
}
|
||||
|
||||
msg := ChatMessage{Role: "assistant"}
|
||||
if len(toolCalls) > 0 {
|
||||
msg.ToolCalls = toolCalls
|
||||
}
|
||||
if contentText != "" {
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
if reasoningText != "" {
|
||||
msg.ReasoningContent = reasoningText
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
out.Choices = []ChatChoice{{
|
||||
Index: 0,
|
||||
Message: msg,
|
||||
FinishReason: finishReason,
|
||||
}}
|
||||
|
||||
if resp.Usage != nil {
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: resp.Usage.InputTokens,
|
||||
CompletionTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
out.Usage = usage
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
if details != nil && details.Reason == "max_output_tokens" {
|
||||
return "length"
|
||||
}
|
||||
return "stop"
|
||||
case "completed":
|
||||
if len(toolCalls) > 0 {
|
||||
return "tool_calls"
|
||||
}
|
||||
return "stop"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesEventToChatState tracks state for converting a sequence of Responses
|
||||
// SSE events into Chat Completions SSE chunks.
|
||||
type ResponsesEventToChatState struct {
|
||||
ID string
|
||||
Model string
|
||||
Created int64
|
||||
SentRole bool
|
||||
SawToolCall bool
|
||||
SawText bool
|
||||
Finalized bool // true after finish chunk has been emitted
|
||||
NextToolCallIndex int // next sequential tool_call index to assign
|
||||
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
|
||||
IncludeUsage bool
|
||||
Usage *ChatUsage
|
||||
}
|
||||
|
||||
// NewResponsesEventToChatState returns an initialised stream state.
|
||||
func NewResponsesEventToChatState() *ResponsesEventToChatState {
|
||||
return &ResponsesEventToChatState{
|
||||
ID: generateChatCmplID(),
|
||||
Created: time.Now().Unix(),
|
||||
OutputIndexToToolIndex: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
|
||||
// or more Chat Completions chunks, updating state as it goes.
|
||||
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
switch evt.Type {
|
||||
case "response.created":
|
||||
return resToChatHandleCreated(evt, state)
|
||||
case "response.output_text.delta":
|
||||
return resToChatHandleTextDelta(evt, state)
|
||||
case "response.output_item.added":
|
||||
return resToChatHandleOutputItemAdded(evt, state)
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
|
||||
// stream ended without a proper completion event (e.g. upstream disconnect).
|
||||
// It is idempotent: if a completion event already emitted the finish chunk,
|
||||
// this returns nil.
|
||||
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if state.Finalized {
|
||||
return nil
|
||||
}
|
||||
state.Finalized = true
|
||||
|
||||
finishReason := "stop"
|
||||
if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
|
||||
|
||||
if state.IncludeUsage && state.Usage != nil {
|
||||
chunks = append(chunks, ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{},
|
||||
Usage: state.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
|
||||
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("data: %s\n\n", data), nil
|
||||
}
|
||||
|
||||
// --- internal handlers ---
|
||||
|
||||
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Response != nil {
|
||||
if evt.Response.ID != "" {
|
||||
state.ID = evt.Response.ID
|
||||
}
|
||||
if state.Model == "" && evt.Response.Model != "" {
|
||||
state.Model = evt.Response.Model
|
||||
}
|
||||
}
|
||||
// Emit the role chunk.
|
||||
if state.SentRole {
|
||||
return nil
|
||||
}
|
||||
state.SentRole = true
|
||||
|
||||
role := "assistant"
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
|
||||
}
|
||||
|
||||
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
state.SawText = true
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
}
|
||||
|
||||
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Item == nil || evt.Item.Type != "function_call" {
|
||||
return nil
|
||||
}
|
||||
|
||||
state.SawToolCall = true
|
||||
idx := state.NextToolCallIndex
|
||||
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
|
||||
state.NextToolCallIndex++
|
||||
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||
ToolCalls: []ChatToolCall{{
|
||||
Index: &idx,
|
||||
ID: evt.Item.CallID,
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: evt.Item.Name,
|
||||
},
|
||||
}},
|
||||
})}
|
||||
}
|
||||
|
||||
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||
ToolCalls: []ChatToolCall{{
|
||||
Index: &idx,
|
||||
Function: ChatFunctionCall{
|
||||
Arguments: evt.Delta,
|
||||
},
|
||||
}},
|
||||
})}
|
||||
}
|
||||
|
||||
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
reasoning := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
state.Finalized = true
|
||||
finishReason := "stop"
|
||||
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
u := evt.Response.Usage
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: u.InputTokens,
|
||||
CompletionTokens: u.OutputTokens,
|
||||
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||
}
|
||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
state.Usage = usage
|
||||
}
|
||||
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||
finishReason = "length"
|
||||
}
|
||||
case "completed":
|
||||
if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
}
|
||||
} else if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
var chunks []ChatCompletionsChunk
|
||||
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
|
||||
|
||||
if state.IncludeUsage && state.Usage != nil {
|
||||
chunks = append(chunks, ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{},
|
||||
Usage: state.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: delta,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
|
||||
empty := ""
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: ChatDelta{Content: &empty},
|
||||
FinishReason: &finishReason,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
|
||||
func generateChatCmplID() string {
|
||||
b := make([]byte, 12)
|
||||
_, _ = rand.Read(b)
|
||||
return "chatcmpl-" + hex.EncodeToString(b)
|
||||
}
|
||||
@@ -12,17 +12,23 @@ import "encoding/json"
|
||||
|
||||
// AnthropicRequest is the request body for POST /v1/messages.
|
||||
type AnthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicOutputConfig controls output generation parameters.
|
||||
type AnthropicOutputConfig struct {
|
||||
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
|
||||
}
|
||||
|
||||
// AnthropicThinking configures extended thinking in the Anthropic API.
|
||||
@@ -47,6 +53,9 @@ type AnthropicContentBlock struct {
|
||||
// type=thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// type=image
|
||||
Source *AnthropicImageSource `json:"source,omitempty"`
|
||||
|
||||
// type=tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -58,9 +67,16 @@ type AnthropicContentBlock struct {
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicImageSource describes the source data for an image content block.
|
||||
type AnthropicImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// AnthropicTool describes a tool available to the model.
|
||||
type AnthropicTool struct {
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
@@ -146,6 +162,7 @@ type ResponsesRequest struct {
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesReasoning configures reasoning effort in the Responses API.
|
||||
@@ -176,8 +193,9 @@ type ResponsesInputItem struct {
|
||||
|
||||
// ResponsesContentPart is a typed content part in a Responses message.
|
||||
type ResponsesContentPart struct {
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"` // data URI for input_image
|
||||
}
|
||||
|
||||
// ResponsesTool describes a tool in the Responses API.
|
||||
@@ -311,6 +329,150 @@ type ResponsesStreamEvent struct {
|
||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI Chat Completions API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
|
||||
type ChatCompletionsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
|
||||
Tools []ChatTool `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
|
||||
|
||||
// Legacy function calling (deprecated but still supported)
|
||||
Functions []ChatFunction `json:"functions,omitempty"`
|
||||
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatStreamOptions configures streaming behavior.
|
||||
type ChatStreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatContentPart is a typed content part in a multi-modal message.
|
||||
type ChatContentPart struct {
|
||||
Type string `json:"type"` // "text" | "image_url"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *ChatImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
// ChatImageURL contains the URL for an image content part.
|
||||
type ChatImageURL struct {
|
||||
URL string `json:"url"`
|
||||
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
|
||||
}
|
||||
|
||||
// ChatTool describes a tool available to the model.
|
||||
type ChatTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Function *ChatFunction `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
// ChatFunction describes a function tool definition.
|
||||
type ChatFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// ChatToolCall represents a tool call made by the assistant.
|
||||
// Index is only populated in streaming chunks (omitted in non-streaming responses).
|
||||
type ChatToolCall struct {
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"` // "function"
|
||||
Function ChatFunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// ChatFunctionCall contains the function name and arguments.
|
||||
type ChatFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
|
||||
type ChatCompletionsResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "chat.completion"
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChoice `json:"choices"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ChatChoice is a single completion choice.
|
||||
type ChatChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message ChatMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
|
||||
}
|
||||
|
||||
// ChatUsage holds token counts in Chat Completions format.
|
||||
type ChatUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ChatTokenDetails provides a breakdown of token usage.
|
||||
type ChatTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
||||
type ChatCompletionsChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "chat.completion.chunk"
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChunkChoice `json:"choices"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ChatChunkChoice is a single choice in a streaming chunk.
|
||||
type ChatChunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta ChatDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"` // pointer: null when not final
|
||||
}
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
|
||||
var DroppedBetas = []string{}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
@@ -18,10 +18,12 @@ func DefaultModels() []Model {
|
||||
return []Model{
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
28
backend/internal/pkg/gemini/models_test.go
Normal file
28
backend/internal/pkg/gemini/models_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package gemini
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
byName := make(map[string]Model, len(models))
|
||||
for _, model := range models {
|
||||
byName[model.Name] = model
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"models/gemini-2.5-flash-image",
|
||||
"models/gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for _, name := range required {
|
||||
model, ok := byName[name]
|
||||
if !ok {
|
||||
t.Fatalf("expected fallback model %q to exist", name)
|
||||
}
|
||||
if len(model.SupportedGenerationMethods) == 0 {
|
||||
t.Fatalf("expected fallback model %q to advertise generation methods", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,10 +13,12 @@ type Model struct {
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
|
||||
23
backend/internal/pkg/geminicli/models_test.go
Normal file
23
backend/internal/pkg/geminicli/models_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package geminicli
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
byID := make(map[string]Model, len(DefaultModels))
|
||||
for _, model := range DefaultModels {
|
||||
byID[model.ID] = model
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for _, id := range required {
|
||||
if _, ok := byID[id]; !ok {
|
||||
t.Fatalf("expected curated Gemini model %q to exist", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,6 +268,7 @@ type IDTokenClaims struct {
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
ChatGPTPlanType string `json:"chatgpt_plan_type"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
// DecodeIDToken decodes the ID Token JWT payload without validating expiration.
|
||||
// Use this for best-effort extraction (e.g., during data import) where the token may be expired.
|
||||
func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
claims, err := DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||
const clockSkewTolerance = 120 // 秒
|
||||
now := time.Now().Unix()
|
||||
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// UserInfo represents user information extracted from ID Token claims.
|
||||
@@ -375,6 +387,7 @@ type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
PlanType string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
|
||||
@@ -96,12 +96,28 @@ type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserSpendingRankingItem represents a user spending ranking row.
|
||||
type UserSpendingRankingItem struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
|
||||
type UserSpendingRankingResponse struct {
|
||||
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -50,6 +51,18 @@ type accountRepository struct {
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
var schedulerNeutralExtraKeyPrefixes = []string{
|
||||
"codex_primary_",
|
||||
"codex_secondary_",
|
||||
"codex_5h_",
|
||||
"codex_7d_",
|
||||
}
|
||||
|
||||
var schedulerNeutralExtraKeys = map[string]struct{}{
|
||||
"codex_usage_updated_at": {},
|
||||
"session_window_utilization": {},
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
@@ -384,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -659,13 +672,10 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -925,6 +935,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1040,6 +1051,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1186,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
} else {
|
||||
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
|
||||
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
|
||||
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
|
||||
if len(updates) == 0 {
|
||||
return false
|
||||
}
|
||||
for key := range updates {
|
||||
if isSchedulerNeutralExtraKey(key) {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isSchedulerNeutralExtraKey(key string) bool {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
if _, ok := schedulerNeutralExtraKeys[key]; ok {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
@@ -1676,13 +1724,139 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||
const dailyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||
const weeklyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextDailyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||
CASE WHEN NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '1 day'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is before today's reset point → next reset is today
|
||||
ELSE (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextWeeklyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute this week's reset point in the configured timezone
|
||||
-- Step 1: get today's date at reset hour in configured tz
|
||||
-- Step 2: compute days forward to target weekday
|
||||
-- Step 3: if same day but past reset hour, advance 7 days
|
||||
CASE
|
||||
WHEN (
|
||||
-- days_forward = (target_day - current_day + 7) % 7
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) = 0 AND NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- Same weekday and past reset hour → next week
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '7 days'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
ELSE (
|
||||
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ ((
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) || ' days')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
-- 总额度:始终递增
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
-- 日额度:仅在 quota_daily_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
@@ -1704,7 +1878,7 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return err
|
||||
}
|
||||
|
||||
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
|
||||
// 任一维度配额刚超限时触发调度快照刷新
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
@@ -1713,14 +1887,14 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
'0'::jsonb
|
||||
), updated_at = NOW()
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,6 +23,7 @@ type AccountRepoSuite struct {
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
if s.accounts == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.accounts[accountID], nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
if s.accounts == nil {
|
||||
s.accounts = make(map[int64]*service.Account)
|
||||
}
|
||||
if account != nil {
|
||||
s.accounts[account.ID] = account
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -132,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "sync-credentials-update",
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.2",
|
||||
},
|
||||
}
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
@@ -558,6 +597,26 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-clear-err",
|
||||
Status: service.StatusError,
|
||||
ErrorMessage: "temporary error",
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusActive, got.Status)
|
||||
s.Require().Empty(got.ErrorMessage)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
@@ -603,6 +662,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
s.Require().Equal("val", got.Extra["key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-neutral",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Extra: map[string]any{"codex_usage_updated_at": "old"},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{
|
||||
accounts: map[int64]*service.Account{
|
||||
account.ID: {
|
||||
ID: account.ID,
|
||||
Platform: account.Platform,
|
||||
Status: service.StatusDisabled,
|
||||
Extra: map[string]any{
|
||||
"codex_usage_updated_at": "old",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
updates := map[string]any{
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
"codex_5h_used_percent": 88.5,
|
||||
"session_window_utilization": 0.42,
|
||||
}
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
|
||||
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
|
||||
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
|
||||
|
||||
var outboxCount int
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
|
||||
s.Require().Zero(outboxCount)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().NotNil(cacheRecorder.accounts[account.ID])
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-codex-exhausted",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
|
||||
"codex_7d_reset_after_seconds": 86400,
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(0, count)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-mixed",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1, count)
|
||||
}
|
||||
|
||||
// --- GetByCRSAccountID ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
|
||||
@@ -165,8 +165,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
|
||||
// as quota_exhausted, and returns the latest quota state in one round trip.
|
||||
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
|
||||
query := `
|
||||
UPDATE api_keys
|
||||
SET
|
||||
quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL
|
||||
RETURNING quota_used, quota, key, status
|
||||
`
|
||||
|
||||
state := &service.APIKeyQuotaUsageState{}
|
||||
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
@@ -476,8 +502,8 @@ func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
@@ -491,9 +517,9 @@ func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64)
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
|
||||
@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
|
||||
user := s.mustCreateUser("quota-state@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
|
||||
key.Quota = 3
|
||||
key.QuotaUsed = 1
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
|
||||
|
||||
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
|
||||
s.Require().NotNil(state)
|
||||
s.Require().Equal(3.5, state.QuotaUsed)
|
||||
s.Require().Equal(3.0, state.Quota)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
|
||||
s.Require().Equal(key.Key, state.Key)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(3.5, got.QuotaUsed)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
|
||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||
type PgDumper struct {
|
||||
cfg *config.DatabaseConfig
|
||||
}
|
||||
|
||||
// NewPgDumper creates a new PgDumper
|
||||
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||
return &PgDumper{cfg: &cfg.Database}
|
||||
}
|
||||
|
||||
// Dump executes pg_dump and returns a streaming reader of the output
|
||||
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||
}
|
||||
|
||||
// Restore executes psql to restore from a streaming reader
|
||||
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = data
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||
type cmdReadCloser struct {
|
||||
io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *cmdReadCloser) Close() error {
|
||||
// Close the pipe first
|
||||
_ = c.ReadCloser.Close()
|
||||
// Wait for the process to exit
|
||||
if err := c.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
backend/internal/repository/backup_s3_store.go
Normal file
116
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||
type S3BackupStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
|
||||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(data),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||
}
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||
}
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||
presignClient := s3.NewPresignClient(s.client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
}, s3.WithPresignExpires(expiry))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &s.bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -147,17 +147,47 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
`)
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// startupCleanupScript 清理非当前进程前缀的槽位成员。
|
||||
// KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。
|
||||
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。
|
||||
startupCleanupScript = redis.NewScript(`
|
||||
local activePrefix = ARGV[1]
|
||||
local slotTTL = tonumber(ARGV[2])
|
||||
local removed = 0
|
||||
for i = 1, #KEYS do
|
||||
local key = KEYS[i]
|
||||
local members = redis.call('ZRANGE', key, 0, -1)
|
||||
for _, member in ipairs(members) do
|
||||
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
|
||||
removed = removed + redis.call('ZREM', key, member)
|
||||
end
|
||||
end
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, slotTTL)
|
||||
end
|
||||
end
|
||||
return removed
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
if activeRequestPrefix == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 清理有序集合中非当前进程前缀的成员
|
||||
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
|
||||
for _, pattern := range slotPatterns {
|
||||
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 删除所有等待队列计数器(重启后计数器失效)
|
||||
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
|
||||
for _, pattern := range waitPatterns {
|
||||
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
|
||||
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
|
||||
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
return fmt.Errorf("del %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
||||
accountID := int64(301)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
||||
now := time.Now().Unix()
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-1"},
|
||||
redis.Z{Score: float64(now), Member: "keep-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-2"},
|
||||
redis.Z{Score: float64(now), Member: "keep-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-1"},
|
||||
redis.Z{Score: now, Member: "activeproc-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-2"},
|
||||
redis.Z{Score: now, Member: "activeproc-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
|
||||
accountID := int64(903)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.EqualValues(s.T(), 0, exists)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
const usageLogsCleanupBatchSize = 10000
|
||||
const usageBillingDedupCleanupBatchSize = 10000
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
if sqlDB == nil {
|
||||
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid
|
||||
FROM usage_logs
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM usage_logs
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageLogsCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM usage_billing_dedup
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
), archived AS (
|
||||
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM victims
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
)
|
||||
DELETE FROM usage_billing_dedup
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageBillingDedupCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
|
||||
@@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
|
||||
@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
SetStatus(k.Status)
|
||||
if k.Quota != 0 {
|
||||
create.SetQuota(k.Quota)
|
||||
}
|
||||
if k.QuotaUsed != 0 {
|
||||
create.SetQuotaUsed(k.QuotaUsed)
|
||||
}
|
||||
if k.RateLimit5h != 0 {
|
||||
create.SetRateLimit5h(k.RateLimit5h)
|
||||
}
|
||||
if k.RateLimit1d != 0 {
|
||||
create.SetRateLimit1d(k.RateLimit1d)
|
||||
}
|
||||
if k.RateLimit7d != 0 {
|
||||
create.SetRateLimit7d(k.RateLimit7d)
|
||||
}
|
||||
if k.Usage5h != 0 {
|
||||
create.SetUsage5h(k.Usage5h)
|
||||
}
|
||||
if k.Usage1d != 0 {
|
||||
create.SetUsage1d(k.Usage1d)
|
||||
}
|
||||
if k.Usage7d != 0 {
|
||||
create.SetUsage7d(k.Usage7d)
|
||||
}
|
||||
if k.Window5hStart != nil {
|
||||
create.SetWindow5hStart(*k.Window5hStart)
|
||||
}
|
||||
if k.Window1dStart != nil {
|
||||
create.SetWindow1dStart(*k.Window1dStart)
|
||||
}
|
||||
if k.Window7dStart != nil {
|
||||
create.SetWindow7dStart(*k.Window7dStart)
|
||||
}
|
||||
if k.ExpiresAt != nil {
|
||||
create.SetExpiresAt(*k.ExpiresAt)
|
||||
}
|
||||
if k.GroupID != nil {
|
||||
create.SetGroupID(*k.GroupID)
|
||||
}
|
||||
|
||||
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
|
||||
// usage_billing_dedup: billing idempotency narrow table
|
||||
var usageBillingDedupRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||
|
||||
var usageBillingDedupArchiveRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||
}
|
||||
|
||||
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||
t.Helper()
|
||||
|
||||
var exists bool
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = $1
|
||||
AND indexname = $2
|
||||
)
|
||||
`, table, index).Scan(&exists)
|
||||
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||
require.True(t, exists, "expected index %s on %s", index, table)
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -16,19 +16,7 @@ type opsRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||
return &opsRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return 0, fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
q := `
|
||||
const insertOpsErrorLogSQL = `
|
||||
INSERT INTO ops_error_logs (
|
||||
request_id,
|
||||
client_request_id,
|
||||
@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs (
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
||||
) RETURNING id`
|
||||
)`
|
||||
|
||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||
return &opsRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return 0, fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
var id int64
|
||||
err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
insertOpsErrorLogSQL+" RETURNING id",
|
||||
opsInsertErrorLogArgs(input)...,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if len(inputs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = stmt.Close()
|
||||
}()
|
||||
|
||||
var inserted int64
|
||||
for _, input := range inputs {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil {
|
||||
return inserted, err
|
||||
}
|
||||
inserted++
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return inserted, err
|
||||
}
|
||||
return inserted, nil
|
||||
}
|
||||
|
||||
func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
|
||||
return []any{
|
||||
opsNullString(input.RequestID),
|
||||
opsNullString(input.ClientRequestID),
|
||||
opsNullInt64(input.UserID),
|
||||
@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs (
|
||||
input.IsRetryable,
|
||||
input.RetryCount,
|
||||
input.CreatedAt,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
|
||||
|
||||
repo := NewOpsRepository(integrationDB).(*opsRepository)
|
||||
now := time.Now().UTC()
|
||||
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
|
||||
{
|
||||
RequestID: "batch-ops-1",
|
||||
ErrorPhase: "upstream",
|
||||
ErrorType: "upstream_error",
|
||||
Severity: "error",
|
||||
StatusCode: 429,
|
||||
ErrorMessage: "rate limited",
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
RequestID: "batch-ops-2",
|
||||
ErrorPhase: "internal",
|
||||
ErrorType: "api_error",
|
||||
Severity: "error",
|
||||
StatusCode: 500,
|
||||
ErrorMessage: "internal error",
|
||||
CreatedAt: now.Add(time.Millisecond),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, inserted)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(12345)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(67890)
|
||||
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
|
||||
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
@@ -73,3 +73,14 @@ func buildReqClientKey(opts reqClientOptions) string {
|
||||
opts.ForceHTTP2,
|
||||
)
|
||||
}
|
||||
|
||||
// CreatePrivacyReqClient creates an HTTP client for OpenAI privacy settings API
|
||||
// This is exported for use by OpenAIPrivacyService
|
||||
// Uses Chrome TLS fingerprint impersonation to bypass Cloudflare checks
|
||||
func CreatePrivacyReqClient(proxyURL string) (*req.Client, error) {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
Impersonate: true, // Enable Chrome TLS fingerprint impersonation
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,16 +20,16 @@ func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanReposit
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
@@ -37,7 +37,7 @@ func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*s
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
@@ -50,7 +50,7 @@ func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accou
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
@@ -65,10 +65,10 @@ func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ type scannable interface {
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
const schedulerOutboxDedupWindow = time.Second
|
||||
|
||||
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
||||
return &schedulerOutboxRepository{db: db}
|
||||
}
|
||||
@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str
|
||||
}
|
||||
payloadArg = encoded
|
||||
}
|
||||
_, err := exec.ExecContext(ctx, `
|
||||
query := `
|
||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`, eventType, accountID, groupID, payloadArg)
|
||||
`
|
||||
args := []any{eventType, accountID, groupID, payloadArg}
|
||||
if schedulerOutboxEventSupportsDedup(eventType) {
|
||||
query = `
|
||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||
SELECT $1, $2, $3, $4
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM scheduler_outbox
|
||||
WHERE event_type = $1
|
||||
AND account_id IS NOT DISTINCT FROM $2
|
||||
AND group_id IS NOT DISTINCT FROM $3
|
||||
AND created_at >= NOW() - make_interval(secs => $5)
|
||||
)
|
||||
`
|
||||
args = append(args, schedulerOutboxDedupWindow.Seconds())
|
||||
}
|
||||
_, err := exec.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func schedulerOutboxEventSupportsDedup(eventType string) bool {
|
||||
switch eventType {
|
||||
case service.SchedulerOutboxEventAccountChanged,
|
||||
service.SchedulerOutboxEventGroupChanged,
|
||||
service.SchedulerOutboxEventFullRebuild:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30"
|
||||
simpleModeLegacyAdminConcurrency = 5
|
||||
simpleModeTargetAdminConcurrency = 30
|
||||
)
|
||||
|
||||
func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
|
||||
upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
if upgraded {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := client.User.Update().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency),
|
||||
).
|
||||
SetConcurrency(simpleModeTargetAdminConcurrency).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("upgrade simple mode admin concurrency: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if err := client.Setting.Create().
|
||||
SetKey(simpleModeAdminConcurrencyUpgradeKey).
|
||||
SetValue(now.Format(time.RFC3339)).
|
||||
SetUpdatedAt(now).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx); err != nil {
|
||||
return fmt.Errorf("persist admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
308
backend/internal/repository/usage_billing_repo.go
Normal file
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type usageBillingRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||||
return &usageBillingRepository{db: sqlDB}
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||||
if cmd == nil {
|
||||
return &service.UsageBillingApplyResult{}, nil
|
||||
}
|
||||
if r == nil || r.db == nil {
|
||||
return nil, errors.New("usage billing repository db is nil")
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
if cmd.RequestID == "" {
|
||||
return nil, service.ErrUsageBillingRequestIDRequired
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !applied {
|
||||
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||||
}
|
||||
|
||||
result := &service.UsageBillingApplyResult{Applied: true}
|
||||
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx = nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
||||
var id int64
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id
|
||||
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
var existingFingerprint string
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var archivedFingerprint string
|
||||
err = tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup_archive
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
||||
if err == nil {
|
||||
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
||||
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
||||
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.BalanceCost > 0 {
|
||||
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.APIKeyQuotaCost > 0 {
|
||||
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.APIKeyQuotaExhausted = exhausted
|
||||
}
|
||||
|
||||
if cmd.APIKeyRateLimitCost > 0 {
|
||||
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||||
const updateSQL = `
|
||||
UPDATE user_subscriptions us
|
||||
SET
|
||||
daily_usage_usd = us.daily_usage_usd + $1,
|
||||
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||||
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||||
updated_at = NOW()
|
||||
FROM groups g
|
||||
WHERE us.id = $2
|
||||
AND us.deleted_at IS NULL
|
||||
AND us.group_id = g.id
|
||||
AND g.deleted_at IS NULL
|
||||
`
|
||||
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE users
|
||||
SET balance = balance - $1,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, amount, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||||
var exhausted bool
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
UPDATE api_keys
|
||||
SET quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0
|
||||
AND status = $3
|
||||
AND quota_used < quota
|
||||
AND quota_used + $1 >= quota
|
||||
THEN $4
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
||||
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, service.ErrAPIKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exhausted, nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, cost, apiKeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
||||
rows, err := tx.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-" + uuid.NewString(),
|
||||
Name: "billing",
|
||||
Quota: 1,
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
BalanceCost: 1.25,
|
||||
APIKeyQuotaCost: 1.25,
|
||||
APIKeyRateLimitCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result1)
|
||||
require.True(t, result1.Applied)
|
||||
require.True(t, result1.APIKeyQuotaExhausted)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result2)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed))
|
||||
require.InDelta(t, 1.25, quotaUsed, 0.000001)
|
||||
|
||||
var usage5h float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
|
||||
require.InDelta(t, 1.25, usage5h, 0.000001)
|
||||
|
||||
var status string
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
|
||||
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
|
||||
|
||||
var dedupCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
|
||||
require.Equal(t, 1, dedupCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
group := mustCreateGroup(t, client, &service.Group{
|
||||
Name: "usage-billing-group-" + uuid.NewString(),
|
||||
Platform: service.PlatformAnthropic,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
GroupID: &group.ID,
|
||||
Key: "sk-usage-billing-sub-" + uuid.NewString(),
|
||||
Name: "billing-sub",
|
||||
})
|
||||
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: 0,
|
||||
SubscriptionID: &subscription.ID,
|
||||
SubscriptionCost: 2.5,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var dailyUsage float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
|
||||
require.InDelta(t, 2.5, dailyUsage, 0.000001)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
|
||||
Name: "billing-conflict",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 2.50,
|
||||
})
|
||||
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-account-" + uuid.NewString(),
|
||||
Name: "billing-account",
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-quota-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 100.0,
|
||||
},
|
||||
})
|
||||
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 3.5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan("aUsed))
|
||||
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
oldRequestID := "dedup-old-" + uuid.NewString()
|
||||
newRequestID := "dedup-new-" + uuid.NewString()
|
||||
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
|
||||
newCreatedAt := time.Now().UTC().Add(-time.Hour)
|
||||
|
||||
_, err := integrationDB.ExecContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
|
||||
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
|
||||
`,
|
||||
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
|
||||
newRequestID, strings.Repeat("b", 64), newCreatedAt,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
var oldCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
|
||||
require.Equal(t, 0, oldCount)
|
||||
|
||||
var newCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
|
||||
require.Equal(t, 1, newCount)
|
||||
|
||||
var archivedCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
|
||||
require.Equal(t, 1, archivedCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-archive-" + uuid.NewString(),
|
||||
Name: "billing-archive",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
_, err = integrationDB.ExecContext(ctx, `
|
||||
UPDATE usage_billing_dedup
|
||||
SET created_at = $1
|
||||
WHERE request_id = $2 AND api_key_id = $3
|
||||
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
s.Require().NotZero(log.ID)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
|
||||
|
||||
const total = 16
|
||||
results := make([]bool, total)
|
||||
errs := make([]error, total)
|
||||
logs := make([]*service.UsageLog, total)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(total)
|
||||
for i := 0; i < total; i++ {
|
||||
i := i
|
||||
logs[i] = &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i], errs[i] = repo.Create(ctx, logs[i])
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
require.NoError(t, errs[i])
|
||||
require.True(t, results[i])
|
||||
require.NotZero(t, logs[i].ID)
|
||||
}
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
|
||||
require.Equal(t, total, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
inserted1, err1 := repo.Create(ctx, log1)
|
||||
inserted2, err2 := repo.Create(ctx, log2)
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
require.True(t, inserted1)
|
||||
require.False(t, inserted2)
|
||||
require.Equal(t, log1.ID, log2.ID)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
const total = 8
|
||||
batch := make([]usageLogCreateRequest, 0, total)
|
||||
logs := make([]*service.UsageLog, 0, total)
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
logs = append(logs, log)
|
||||
batch = append(batch, usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
})
|
||||
}
|
||||
|
||||
repo.flushCreateBatch(integrationDB, batch)
|
||||
|
||||
insertedCount := 0
|
||||
var firstID int64
|
||||
for idx, req := range batch {
|
||||
res := <-req.resultCh
|
||||
require.NoError(t, res.err)
|
||||
if res.inserted {
|
||||
insertedCount++
|
||||
}
|
||||
require.NotZero(t, logs[idx].ID)
|
||||
if idx == 0 {
|
||||
firstID = logs[idx].ID
|
||||
} else {
|
||||
require.Equal(t, firstID, logs[idx].ID)
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, insertedCount)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
var count int
|
||||
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||
return err == nil && count == 1
|
||||
}, 3*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
|
||||
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
|
||||
|
||||
err := repo.CreateBestEffort(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateDropped(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
repo.createBatchCh <- usageLogCreateRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
_, err := repo.createBatched(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
req := <-repo.createBatchCh
|
||||
require.NotNil(t, req.shared)
|
||||
cancel()
|
||||
|
||||
err := <-errCh
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
req := usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
shared: &usageLogCreateShared{},
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
}
|
||||
req.shared.state.Store(usageLogCreateStateCanceled)
|
||||
|
||||
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
|
||||
|
||||
res := <-req.resultCh
|
||||
require.False(t, res.inserted)
|
||||
require.Error(t, res.err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
|
||||
@@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(), // image_size
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.Equal(t, int64(99), log.ID)
|
||||
require.Nil(t, log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-service-tier",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
WithArgs(
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
log.RateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
|
||||
inserted, err := repo.Create(context.Background(), log)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
@@ -183,6 +248,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
|
||||
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
|
||||
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
|
||||
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
|
||||
|
||||
mock.ExpectQuery("WITH user_spend AS \\(").
|
||||
WithArgs(start, end, 12).
|
||||
WillReturnRows(rows)
|
||||
|
||||
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: []usagestats.UserSpendingRankingItem{
|
||||
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
|
||||
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
|
||||
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
||||
},
|
||||
TotalActualCost: 40.0,
|
||||
}, got)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -280,11 +374,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
@@ -316,13 +413,53 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "flex", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeStream, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.False(t, log.OpenAIWSMode)
|
||||
})
|
||||
|
||||
t.Run("service_tier_is_scanned", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(3),
|
||||
int64(12),
|
||||
int64(22),
|
||||
int64(32),
|
||||
sql.NullString{Valid: true, String: "req-3"},
|
||||
"gpt-5.4",
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -3,8 +3,11 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-batch-no-update",
|
||||
Model: "gpt-5",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalCost: 1.2,
|
||||
ActualCost: 1.2,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
prepared := prepareUsageLogInsert(log)
|
||||
|
||||
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
|
||||
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
|
||||
})
|
||||
|
||||
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
|
||||
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
|
||||
}
|
||||
|
||||
@@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||
query := `
|
||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
||||
FROM user_group_rate_multipliers ugr
|
||||
JOIN users u ON u.id = ugr.user_id
|
||||
WHERE ugr.group_id = $1
|
||||
ORDER BY ugr.user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var result []service.UserGroupRateEntry
|
||||
for rows.Next() {
|
||||
var entry service.UserGroupRateEntry
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
@@ -164,6 +193,31 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
|
||||
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
||||
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
userIDs := make([]int64, len(entries))
|
||||
rates := make([]float64, len(entries))
|
||||
for i, e := range entries {
|
||||
userIDs[i] = e.UserID
|
||||
rates[i] = e.RateMultiplier
|
||||
}
|
||||
now := time.Now()
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
SELECT data.user_id, $1::bigint, data.rate_multiplier, $2::timestamptz, $2::timestamptz
|
||||
FROM unnest($3::bigint[], $4::double precision[]) AS data(user_id, rate_multiplier)
|
||||
ON CONFLICT (user_id, group_id)
|
||||
DO UPDATE SET rate_multiplier = EXCLUDED.rate_multiplier, updated_at = EXCLUDED.updated_at
|
||||
`, groupID, now, pq.Array(userIDs), pq.Array(rates))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user