mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 08:20:23 +08:00
Compare commits
79 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cc407bc0e | ||
|
|
00a0a12138 | ||
|
|
b08767a4f9 | ||
|
|
ac6bde7a98 | ||
|
|
d2d41d68dd | ||
|
|
944b7f7617 | ||
|
|
53825eb073 | ||
|
|
1a7f49513f | ||
|
|
885a2ce7ef | ||
|
|
14ba80a0af | ||
|
|
5fa22fdf82 | ||
|
|
bcaae2eb91 | ||
|
|
767a41e263 | ||
|
|
252d6c5301 | ||
|
|
7a4e65ad4b | ||
|
|
a582aa89a9 | ||
|
|
acefa1da12 | ||
|
|
a88698f3fc | ||
|
|
ebc6755b33 | ||
|
|
c8eff34388 | ||
|
|
f19b03825b | ||
|
|
25178cdbe1 | ||
|
|
a461538d58 | ||
|
|
b43ee62947 | ||
|
|
ebe6f418f3 | ||
|
|
391e79f8ee | ||
|
|
c7fcb7a84b | ||
|
|
87f4ed591e | ||
|
|
440d2e28ed | ||
|
|
6cb8980404 | ||
|
|
fe752bbd35 | ||
|
|
c74d451fa2 | ||
|
|
12d743fb35 | ||
|
|
6acb9f7910 | ||
|
|
eb6f5c6927 | ||
|
|
7ccb4c8ea3 | ||
|
|
4ce986d47d | ||
|
|
91ef085d7d | ||
|
|
97aaa24733 | ||
|
|
faf6441633 | ||
|
|
00c151b463 | ||
|
|
106b20cdbf | ||
|
|
c069b3b1e8 | ||
|
|
a2ae9f1f27 | ||
|
|
4cd6d86426 | ||
|
|
fa72f1947a | ||
|
|
9ee7d3935d | ||
|
|
1071fe0ac7 | ||
|
|
0be003377f | ||
|
|
ca3f497b56 | ||
|
|
034b84b707 | ||
|
|
1624523c4e | ||
|
|
313afe14ce | ||
|
|
01180b316f | ||
|
|
ee7d061001 | ||
|
|
60c5949a74 | ||
|
|
2ebbd4c94d | ||
|
|
785115c62b | ||
|
|
e643fc382c | ||
|
|
34aad82ac3 | ||
|
|
0c29468f90 | ||
|
|
9301dae63e | ||
|
|
2475d4a205 | ||
|
|
be75fc3474 | ||
|
|
785e049af3 | ||
|
|
be4e49e6d7 | ||
|
|
1307d604e7 | ||
|
|
45d57018eb | ||
|
|
03bf348530 | ||
|
|
cab60ef735 | ||
|
|
a3791104f9 | ||
|
|
2b3e40bb2a | ||
|
|
0c1dcad429 | ||
|
|
101ef0cf62 | ||
|
|
0debe0a80c | ||
|
|
d22e62ac8a | ||
|
|
1ee17383f8 | ||
|
|
b59c79c458 | ||
|
|
c28f691f32 |
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -42,6 +42,6 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.11
|
||||
version: v2.9
|
||||
args: --timeout=30m
|
||||
working-directory: backend
|
||||
28
README.md
28
README.md
@@ -150,14 +150,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# Start services
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# View logs
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**What the script does:**
|
||||
- Downloads `docker-compose.local.yml` and `.env.example`
|
||||
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
|
||||
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
||||
- Creates `.env` file with auto-generated secrets
|
||||
- Creates data directories (uses local directories for easy backup/migration)
|
||||
@@ -522,6 +522,28 @@ sub2api/
|
||||
└── install.sh # One-click installation script
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
> **Please read carefully before using this project:**
|
||||
>
|
||||
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
|
||||
>
|
||||
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
28
README_CN.md
28
README_CN.md
@@ -154,14 +154,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# 启动服务
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# 查看日志
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**脚本功能:**
|
||||
- 下载 `docker-compose.local.yml` 和 `.env.example`
|
||||
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example`
|
||||
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
||||
- 创建 `.env` 文件并填充自动生成的密钥
|
||||
- 创建数据目录(使用本地目录,便于备份和迁移)
|
||||
@@ -588,6 +588,28 @@ sub2api/
|
||||
└── install.sh # 一键安装脚本
|
||||
```
|
||||
|
||||
## 免责声明
|
||||
|
||||
> **使用本项目前请仔细阅读:**
|
||||
>
|
||||
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
|
||||
>
|
||||
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
@@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
@@ -162,7 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
@@ -229,7 +229,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
|
||||
@@ -1402,7 +1402,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
}
|
||||
}
|
||||
|
||||
enrichCredentialsFromIDToken(&item)
|
||||
|
||||
accountInput := &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||
// Existing credential values are never overwritten — only missing fields are filled.
|
||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||
if item.Credentials == nil {
|
||||
return
|
||||
}
|
||||
// Only enrich OpenAI/Sora OAuth accounts
|
||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, _ := item.Credentials["id_token"].(string)
|
||||
if strings.TrimSpace(idToken) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeIDToken skips expiry validation — safe for imported data
|
||||
claims, err := openai.DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := claims.GetUserInfo()
|
||||
if userInfo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fill missing fields only (never overwrite existing values)
|
||||
setIfMissing := func(key, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
if existing, _ := item.Credentials[key].(string); existing == "" {
|
||||
item.Credentials[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
setIfMissing("email", userInfo.Email)
|
||||
setIfMissing("plan_type", userInfo.PlanType)
|
||||
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
|
||||
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
|
||||
setIfMissing("organization_id", userInfo.OrganizationID)
|
||||
}
|
||||
|
||||
func normalizeProxyStatus(status string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||
switch normalized {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -660,6 +662,42 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService != nil {
|
||||
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverState handles unified recovery of recoverable account runtime state.
|
||||
// POST /api/v1/admin/accounts/:id/recover-state
|
||||
func (h *AccountHandler) RecoverState(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
}); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
@@ -715,52 +753,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
||||
// refreshSingleAccount refreshes credentials for a single OAuth account.
|
||||
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
|
||||
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
|
||||
}
|
||||
|
||||
var newCredentials map[string]any
|
||||
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformGemini {
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
|
||||
}
|
||||
|
||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -770,10 +787,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformAntigravity {
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -792,37 +808,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
return updatedAccount, "missing_project_id_temporary", nil
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
|
||||
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
@@ -844,20 +850,51 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
if warning == "missing_project_id_temporary" {
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
@@ -913,14 +950,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// BatchClearError handles batch clearing account errors
|
||||
// POST /api/v1/admin/accounts/batch-clear-error
|
||||
func (h *AccountHandler) BatchClearError(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, id := range req.AccountIDs {
|
||||
accountID := id // 闭包捕获
|
||||
g.Go(func() error {
|
||||
account, err := h.adminService.ClearAccountError(gctx, accountID)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": accountID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
successCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchRefresh handles batch refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/batch-refresh
|
||||
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 建立已获取账号的 ID 集合,检测缺失的 ID
|
||||
foundIDs := make(map[int64]bool, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc != nil {
|
||||
foundIDs[acc.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
var warnings []gin.H
|
||||
|
||||
// 将不存在的账号 ID 标记为失败
|
||||
for _, id := range req.AccountIDs {
|
||||
if !foundIDs[id] {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": id,
|
||||
"error": "account not found",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, account := range accounts {
|
||||
acc := account // 闭包捕获
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
g.Go(func() error {
|
||||
_, warning, err := h.refreshSingleAccount(gctx, acc)
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
successCount++
|
||||
if warning != "" {
|
||||
warnings = append(warnings, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"warning": warning,
|
||||
})
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
// POST /api/v1/admin/accounts/batch
|
||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
|
||||
@@ -25,6 +25,7 @@ type createScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
@@ -32,6 +33,7 @@ type updateScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
@@ -68,6 +70,9 @@ func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
plan.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
@@ -109,6 +114,9 @@ func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
existing.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
|
||||
@@ -1348,6 +1348,118 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
// GET /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: settings.Enabled,
|
||||
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRectifierSettingsRequest 更新整流器配置请求
|
||||
type UpdateRectifierSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// UpdateRectifierSettings 更新请求整流器配置
|
||||
// PUT /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
var req UpdateRectifierSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.RectifierSettings{
|
||||
Enabled: req.Enabled,
|
||||
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||
// GET /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
|
||||
for i, r := range settings.Rules {
|
||||
rules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: rules})
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
|
||||
type UpdateBetaPolicySettingsRequest struct {
|
||||
Rules []dto.BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettings 更新 Beta 策略配置
|
||||
// PUT /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
|
||||
var req UpdateBetaPolicySettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]service.BetaPolicyRule, len(req.Rules))
|
||||
for i, r := range req.Rules {
|
||||
rules[i] = service.BetaPolicyRule(r)
|
||||
}
|
||||
|
||||
settings := &service.BetaPolicySettings{Rules: rules}
|
||||
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch to return updated settings
|
||||
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
|
||||
for i, r := range updated.Rules {
|
||||
outRules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
|
||||
if tokenErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
|
||||
return
|
||||
}
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", "invitation_required")
|
||||
fragment.Set("pending_oauth_token", pendingToken)
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
return
|
||||
}
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
}
|
||||
|
||||
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||
// the invitation code and creating the user account.
|
||||
// POST /api/v1/auth/oauth/linuxdo/complete-registration
|
||||
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
var req completeLinuxDoOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
|
||||
@@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
out := &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -98,6 +98,19 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) {
|
||||
t := k.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
out.Reset5hAt = &t
|
||||
}
|
||||
if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) {
|
||||
t := k.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
out.Reset1dAt = &t
|
||||
}
|
||||
if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) {
|
||||
t := k.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
out.Reset7dAt = &t
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
@@ -125,9 +138,9 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
@@ -255,11 +268,19 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
if out.QuotaLimit != nil {
|
||||
used := a.GetQuotaUsed()
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -475,6 +496,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
|
||||
@@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -161,6 +161,26 @@ type StreamTimeoutSettings struct {
|
||||
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置 DTO
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// BetaPolicyRule Beta 策略规则 DTO
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// BetaPolicySettings Beta 策略配置 DTO
|
||||
type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -57,6 +57,9 @@ type APIKey struct {
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
Reset5hAt *time.Time `json:"reset_5h_at,omitempty"`
|
||||
Reset1dAt *time.Time `json:"reset_1d_at,omitempty"`
|
||||
Reset7dAt *time.Time `json:"reset_7d_at,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -193,8 +196,12 @@ type Account struct {
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"`
|
||||
QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"`
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -315,6 +322,8 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
maxSameAccountRetries = 3
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
|
||||
@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
for i := 1; i <= maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, i, fs.SameAccountRetryCount[100])
|
||||
}
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
// 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
}
|
||||
// 再次触发时才会执行 TempUnschedule + 切换
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
}
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
@@ -652,6 +652,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
// Beta policy block: return 400 immediately, no failover
|
||||
var betaBlockedErr *service.BetaBlockedError
|
||||
if errors.As(err, &betaBlockedErr) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||
return
|
||||
}
|
||||
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||
@@ -972,33 +979,45 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.EffectiveUsage5h()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit5h-used),
|
||||
"window_start": rateLimitData.Window5hStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) {
|
||||
entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.EffectiveUsage1d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit1d-used),
|
||||
"window_start": rateLimitData.Window1dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) {
|
||||
entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.EffectiveUsage7d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit7d-used),
|
||||
"window_start": rateLimitData.Window7dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) {
|
||||
entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if len(rateLimits) > 0 {
|
||||
resp["rate_limits"] = rateLimits
|
||||
|
||||
@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
@@ -155,6 +156,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // sessionLimitCache
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
|
||||
@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -212,6 +213,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -259,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
@@ -288,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -538,9 +560,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
|
||||
// 而非 OpenAI 的 session_id/conversation_id headers。
|
||||
// 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。
|
||||
if sessionHash == "" || promptCacheKey == "" {
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = service.GenerateSessionUUID(seed)
|
||||
}
|
||||
if sessionHash == "" {
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -602,6 +640,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -641,6 +680,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -1456,6 +1514,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
|
||||
if sessionHash != "" || account == nil || !account.IsPoolMode() {
|
||||
return sessionHash
|
||||
}
|
||||
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
|
||||
return "openai-pool-retry-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
|
||||
@@ -2207,7 +2207,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -445,6 +445,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
|
||||
@@ -49,8 +49,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
||||
var defaultUserAgentVersion = "1.20.4"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -631,7 +631,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.Contains(t, resp.Include, "reasoning.encrypted_content")
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
@@ -648,7 +649,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
@@ -663,8 +665,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
// Default effort applies (high → xhigh) even when thinking is disabled.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
@@ -676,7 +679,93 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
// Default effort applies (high → xhigh) when no thinking/output_config is set.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// output_config.effort override tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
|
||||
// Default is xhigh, but output_config.effort="low" overrides. low→low after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "low"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "low", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
|
||||
// No thinking field, but output_config.effort="medium" → creates reasoning.
|
||||
// medium→high after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "medium"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
|
||||
// output_config.effort="high" → mapped to "xhigh".
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "high"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
|
||||
// No output_config → default xhigh regardless of thinking.type.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
|
||||
// output_config present but effort empty (e.g. only format set) → default xhigh.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -733,3 +822,188 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", fn["name"])
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Image content block conversion tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_UserImageBlock(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"text","text":"What is in this image?"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "What is in this image?", parts[0].Text)
|
||||
assert.Equal(t, "input_image", parts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ImageOnlyUserMessage(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"/9j/4AAQ"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/jpeg;base64,/9j/4AAQ", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolResultWithImage(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Read the screenshot"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_1","name":"Read","input":{"file_path":"/tmp/screen.png"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"toolu_1","content":[
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output + user(image) = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
// function_call_output should have text-only output (no image).
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "fc_toolu_1", items[2].CallID)
|
||||
assert.Equal(t, "(empty)", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
assert.Equal(t, "user", items[3].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[3].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolResultMixed(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Describe the file"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_2","name":"Read","input":{"file_path":"/tmp/photo.png"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"toolu_2","content":[
|
||||
{"type":"text","text":"File metadata: 800x600 PNG"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output + user(image) = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
// function_call_output should have text-only output.
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "File metadata: 800x600 PNG", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
assert.Equal(t, "user", items[3].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[3].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/png;base64,AAAA", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_TextOnlyToolResultBackwardCompat(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Check weather"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"call_1","content":[
|
||||
{"type":"text","text":"Sunny, 72°F"}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output = 3
|
||||
require.Len(t, items, 3)
|
||||
|
||||
// Text-only tool_result should produce a plain string.
|
||||
assert.Equal(t, "Sunny, 72°F", items[2].Output)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"image","source":{"type":"base64","media_type":"","data":"iVBOR"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
// Should default to image/png when media_type is empty.
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
@@ -45,18 +45,16 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
out.Tools = convertAnthropicToolsToResponses(req.Tools)
|
||||
}
|
||||
|
||||
// Convert thinking → reasoning.
|
||||
// generate_summary="auto" causes the upstream to emit reasoning_summary_text
|
||||
// streaming events; the include array only needs reasoning.encrypted_content
|
||||
// (already set above) for content continuity.
|
||||
if req.Thinking != nil {
|
||||
switch req.Thinking.Type {
|
||||
case "enabled":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "high", Summary: "auto"}
|
||||
case "adaptive":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "medium", Summary: "auto"}
|
||||
}
|
||||
// "disabled" or unknown → omit reasoning
|
||||
// Determine reasoning effort: only output_config.effort controls the
|
||||
// level; thinking.type is ignored. Default is xhigh when unset.
|
||||
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
|
||||
effort := "high" // default → maps to xhigh
|
||||
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
|
||||
effort = req.OutputConfig.Effort
|
||||
}
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: mapAnthropicEffortToResponses(effort),
|
||||
Summary: "auto",
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
@@ -169,7 +167,7 @@ func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, err
|
||||
|
||||
// anthropicUserToResponses handles an Anthropic user message. Content can be a
|
||||
// plain string or an array of blocks. tool_result blocks are extracted into
|
||||
// function_call_output items.
|
||||
// function_call_output items. Image blocks are converted to input_image parts.
|
||||
func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
@@ -184,28 +182,46 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error)
|
||||
}
|
||||
|
||||
var out []ResponsesInputItem
|
||||
var toolResultImageParts []ResponsesContentPart
|
||||
|
||||
// Extract tool_result blocks → function_call_output items.
|
||||
// Images inside tool_results are extracted separately because the
|
||||
// Responses API function_call_output.output only accepts strings.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_result" {
|
||||
continue
|
||||
}
|
||||
text := extractAnthropicToolResultText(b)
|
||||
if text == "" {
|
||||
// OpenAI Responses API requires "output" field; use placeholder for empty results.
|
||||
text = "(empty)"
|
||||
}
|
||||
outputText, imageParts := convertToolResultOutput(b)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Type: "function_call_output",
|
||||
CallID: toResponsesCallID(b.ToolUseID),
|
||||
Output: text,
|
||||
Output: outputText,
|
||||
})
|
||||
toolResultImageParts = append(toolResultImageParts, imageParts...)
|
||||
}
|
||||
|
||||
// Remaining text blocks → user message.
|
||||
text := extractAnthropicTextFromBlocks(blocks)
|
||||
if text != "" {
|
||||
content, _ := json.Marshal(text)
|
||||
// Remaining text + image blocks → user message with content parts.
|
||||
// Also include images extracted from tool_results so the model can see them.
|
||||
var parts []ResponsesContentPart
|
||||
for _, b := range blocks {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
if b.Text != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(b.Source); uri != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, toolResultImageParts...)
|
||||
|
||||
if len(parts) > 0 {
|
||||
content, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ResponsesInputItem{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
@@ -290,26 +306,64 @@ func fromResponsesCallID(id string) string {
|
||||
return id
|
||||
}
|
||||
|
||||
// extractAnthropicToolResultText gets the text content from a tool_result block.
|
||||
func extractAnthropicToolResultText(b AnthropicContentBlock) string {
|
||||
if len(b.Content) == 0 {
|
||||
// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string.
|
||||
// Returns "" if the source is nil or has no data.
|
||||
func anthropicImageToDataURI(src *AnthropicImageSource) string {
|
||||
if src == nil || src.Data == "" {
|
||||
return ""
|
||||
}
|
||||
mediaType := src.MediaType
|
||||
if mediaType == "" {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
return "data:" + mediaType + ";base64," + src.Data
|
||||
}
|
||||
|
||||
// convertToolResultOutput extracts text and image content from a tool_result
|
||||
// block. Returns the text as a string for the function_call_output Output
|
||||
// field, plus any image parts that must be sent in a separate user message
|
||||
// (the Responses API output field only accepts strings).
|
||||
func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) {
|
||||
if len(b.Content) == 0 {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Try plain string content.
|
||||
var s string
|
||||
if err := json.Unmarshal(b.Content, &s); err == nil {
|
||||
return s
|
||||
if s == "" {
|
||||
s = "(empty)"
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Array of content blocks — may contain text and/or images.
|
||||
var inner []AnthropicContentBlock
|
||||
if err := json.Unmarshal(b.Content, &inner); err == nil {
|
||||
var parts []string
|
||||
for _, ib := range inner {
|
||||
if ib.Type == "text" && ib.Text != "" {
|
||||
parts = append(parts, ib.Text)
|
||||
if err := json.Unmarshal(b.Content, &inner); err != nil {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Separate text (for function_call_output) from images (for user message).
|
||||
var textParts []string
|
||||
var imageParts []ResponsesContentPart
|
||||
for _, ib := range inner {
|
||||
switch ib.Type {
|
||||
case "text":
|
||||
if ib.Text != "" {
|
||||
textParts = append(textParts, ib.Text)
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(ib.Source); uri != "" {
|
||||
imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
return ""
|
||||
|
||||
text := strings.Join(textParts, "\n\n")
|
||||
if text == "" {
|
||||
text = "(empty)"
|
||||
}
|
||||
return text, imageParts
|
||||
}
|
||||
|
||||
// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/
|
||||
@@ -324,6 +378,23 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
|
||||
// OpenAI Responses API effort levels.
|
||||
//
|
||||
// low → low
|
||||
// medium → high
|
||||
// high → xhigh
|
||||
func mapAnthropicEffortToResponses(effort string) string {
|
||||
switch effort {
|
||||
case "medium":
|
||||
return "high"
|
||||
case "high":
|
||||
return "xhigh"
|
||||
default:
|
||||
return effort // "low" and any unknown values pass through unchanged
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
|
||||
// Responses API tools. Server-side tools like web_search are mapped to their
|
||||
// OpenAI equivalents; regular tools become function tools.
|
||||
|
||||
@@ -12,17 +12,23 @@ import "encoding/json"
|
||||
|
||||
// AnthropicRequest is the request body for POST /v1/messages.
|
||||
type AnthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicOutputConfig controls output generation parameters.
|
||||
type AnthropicOutputConfig struct {
|
||||
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
|
||||
}
|
||||
|
||||
// AnthropicThinking configures extended thinking in the Anthropic API.
|
||||
@@ -47,6 +53,9 @@ type AnthropicContentBlock struct {
|
||||
// type=thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// type=image
|
||||
Source *AnthropicImageSource `json:"source,omitempty"`
|
||||
|
||||
// type=tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -58,9 +67,16 @@ type AnthropicContentBlock struct {
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicImageSource describes the source data for an image content block.
|
||||
type AnthropicImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// AnthropicTool describes a tool available to the model.
|
||||
type AnthropicTool struct {
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
@@ -146,6 +162,7 @@ type ResponsesRequest struct {
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesReasoning configures reasoning effort in the Responses API.
|
||||
@@ -176,8 +193,9 @@ type ResponsesInputItem struct {
|
||||
|
||||
// ResponsesContentPart is a typed content part in a Responses message.
|
||||
type ResponsesContentPart struct {
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"` // data URI for input_image
|
||||
}
|
||||
|
||||
// ResponsesTool describes a tool in the Responses API.
|
||||
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
|
||||
var DroppedBetas = []string{}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
@@ -268,6 +268,7 @@ type IDTokenClaims struct {
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
ChatGPTPlanType string `json:"chatgpt_plan_type"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
// DecodeIDToken decodes the ID Token JWT payload without validating expiration.
|
||||
// Use this for best-effort extraction (e.g., during data import) where the token may be expired.
|
||||
func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
claims, err := DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||
const clockSkewTolerance = 120 // 秒
|
||||
now := time.Now().Unix()
|
||||
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// UserInfo represents user information extracted from ID Token claims.
|
||||
@@ -375,6 +387,7 @@ type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
PlanType string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
|
||||
@@ -659,13 +659,10 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -925,6 +922,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1040,6 +1038,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1676,13 +1675,47 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
-- 总额度:始终递增
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
-- 日额度:仅在 quota_daily_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
@@ -1704,7 +1737,7 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return err
|
||||
}
|
||||
|
||||
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
|
||||
// 任一维度配额刚超限时触发调度快照刷新
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
@@ -1713,14 +1746,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
'0'::jsonb
|
||||
), updated_at = NOW()
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -558,6 +558,26 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-clear-err",
|
||||
Status: service.StatusError,
|
||||
ErrorMessage: "temporary error",
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusActive, got.Status)
|
||||
s.Require().Empty(got.ErrorMessage)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
|
||||
@@ -165,8 +165,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -476,8 +476,8 @@ func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
@@ -491,9 +491,9 @@ func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64)
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
|
||||
@@ -147,17 +147,47 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
`)
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// startupCleanupScript 清理非当前进程前缀的槽位成员。
|
||||
// KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。
|
||||
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。
|
||||
startupCleanupScript = redis.NewScript(`
|
||||
local activePrefix = ARGV[1]
|
||||
local slotTTL = tonumber(ARGV[2])
|
||||
local removed = 0
|
||||
for i = 1, #KEYS do
|
||||
local key = KEYS[i]
|
||||
local members = redis.call('ZRANGE', key, 0, -1)
|
||||
for _, member in ipairs(members) do
|
||||
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
|
||||
removed = removed + redis.call('ZREM', key, member)
|
||||
end
|
||||
end
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, slotTTL)
|
||||
end
|
||||
end
|
||||
return removed
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
if activeRequestPrefix == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 清理有序集合中非当前进程前缀的成员
|
||||
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
|
||||
for _, pattern := range slotPatterns {
|
||||
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 删除所有等待队列计数器(重启后计数器失效)
|
||||
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
|
||||
for _, pattern := range waitPatterns {
|
||||
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
|
||||
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
|
||||
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
return fmt.Errorf("del %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
||||
accountID := int64(301)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
||||
now := time.Now().Unix()
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-1"},
|
||||
redis.Z{Score: float64(now), Member: "keep-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-2"},
|
||||
redis.Z{Score: float64(now), Member: "keep-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-1"},
|
||||
redis.Z{Score: now, Member: "activeproc-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-2"},
|
||||
redis.Z{Score: now, Member: "activeproc-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
|
||||
accountID := int64(903)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.EqualValues(s.T(), 0, exists)
|
||||
}
|
||||
|
||||
@@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
|
||||
@@ -20,16 +20,16 @@ func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanReposit
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
@@ -37,7 +37,7 @@ func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*s
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
@@ -50,7 +50,7 @@ func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accou
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
@@ -65,10 +65,10 @@ func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ type scannable interface {
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30"
|
||||
simpleModeLegacyAdminConcurrency = 5
|
||||
simpleModeTargetAdminConcurrency = 30
|
||||
)
|
||||
|
||||
func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
|
||||
upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
if upgraded {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := client.User.Update().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency),
|
||||
).
|
||||
SetConcurrency(simpleModeTargetAdminConcurrency).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("upgrade simple mode admin concurrency: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if err := client.Setting.Create().
|
||||
SetKey(simpleModeAdminConcurrencyUpgradeKey).
|
||||
SetValue(now.Format(time.RFC3339)).
|
||||
SetUpdatedAt(now).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx); err != nil {
|
||||
return fmt.Errorf("persist admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
@@ -135,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
@@ -144,7 +145,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -158,6 +159,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
mediaType := nullString(log.MediaType)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -198,6 +200,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -2505,6 +2508,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
mediaType sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
@@ -2544,6 +2548,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&mediaType,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
@@ -2614,6 +2619,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if mediaType.Valid {
|
||||
log.MediaType = &mediaType.String
|
||||
}
|
||||
if serviceTier.Valid {
|
||||
log.ServiceTier = &serviceTier.String
|
||||
}
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
|
||||
@@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(), // image_size
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.Equal(t, int64(99), log.ID)
|
||||
require.Nil(t, log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-service-tier",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
WithArgs(
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
log.RateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
|
||||
inserted, err := repo.Create(context.Background(), log)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
@@ -280,11 +345,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
@@ -316,13 +384,53 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "flex", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeStream, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.False(t, log.OpenAIWSMode)
|
||||
})
|
||||
|
||||
t.Run("service_tier_is_scanned", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(3),
|
||||
int64(12),
|
||||
int64(22),
|
||||
int64(32),
|
||||
sql.NullString{Valid: true, String: "req-3"},
|
||||
"gpt-5.4",
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -210,8 +210,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"allow_messages_dispatch": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"allow_messages_dispatch": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -643,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
|
||||
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
|
||||
@@ -244,6 +244,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
@@ -263,6 +264,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError)
|
||||
accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh)
|
||||
|
||||
// Antigravity 默认模型映射
|
||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||
@@ -392,6 +395,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 流超时处理配置
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
// 请求整流器配置
|
||||
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
||||
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
||||
// Beta 策略配置
|
||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||
// Sora S3 存储配置
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
|
||||
@@ -61,6 +61,12 @@ func RegisterAuthRoutes(
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
auth.POST("/oauth/linuxdo/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteLinuxDoOAuthRegistration,
|
||||
)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
|
||||
@@ -647,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
defaultPoolModeRetryCount = 3
|
||||
maxPoolModeRetryCount = 10
|
||||
)
|
||||
|
||||
// GetPoolModeRetryCount 返回池模式同账号重试次数。
|
||||
// 未配置或配置非法时回退为默认值 3;小于 0 按 0 处理;过大则截断到 10。
|
||||
func (a *Account) GetPoolModeRetryCount() int {
|
||||
if a == nil || !a.IsPoolMode() || a.Credentials == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
raw, ok := a.Credentials["pool_mode_retry_count"]
|
||||
if !ok || raw == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
count := parsePoolModeRetryCount(raw)
|
||||
if count < 0 {
|
||||
return 0
|
||||
}
|
||||
if count > maxPoolModeRetryCount {
|
||||
return maxPoolModeRetryCount
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func parsePoolModeRetryCount(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
|
||||
// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码
|
||||
func isPoolModeRetryableStatus(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -1134,33 +1203,97 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetQuotaLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
return a.getExtraFloat64("quota_limit")
|
||||
}
|
||||
|
||||
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
|
||||
func (a *Account) GetQuotaUsed() float64 {
|
||||
return a.getExtraFloat64("quota_used")
|
||||
}
|
||||
|
||||
// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaDailyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_daily_limit")
|
||||
}
|
||||
|
||||
// GetQuotaDailyUsed 获取当日已用额度(美元)
|
||||
func (a *Account) GetQuotaDailyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_daily_used")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyLimit 获取周额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaWeeklyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_limit")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyUsed 获取本周已用额度(美元)
|
||||
func (a *Account) GetQuotaWeeklyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_used")
|
||||
}
|
||||
|
||||
// getExtraFloat64 从 Extra 中读取指定 key 的 float64 值
|
||||
func (a *Account) getExtraFloat64(key string) float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_used"]; ok {
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
limit := a.GetQuotaLimit()
|
||||
if limit <= 0 {
|
||||
return false
|
||||
// getExtraTime 从 Extra 中读取 RFC3339 时间戳
|
||||
func (a *Account) getExtraTime(key string) time.Time {
|
||||
if a.Extra == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return a.GetQuotaUsed() >= limit
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
|
||||
return t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
}
|
||||
|
||||
// isPeriodExpired 检查指定周期(自 periodStart 起经过 dur)是否已过期
|
||||
func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true // 从未使用过,视为过期(下次 increment 会初始化)
|
||||
}
|
||||
return time.Since(periodStart) >= dur
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true)
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
// 总额度
|
||||
if limit := a.GetQuotaLimit(); limit > 0 && a.GetQuotaUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
|
||||
117
backend/internal/service/account_pool_mode_test.go
Normal file
117
backend/internal/service/account_pool_mode_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetPoolModeRetryCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "default_when_not_pool_mode",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "default_when_missing_retry_count",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "supports_float64_from_json_credentials",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": float64(5),
|
||||
},
|
||||
},
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "supports_json_number",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": json.Number("4"),
|
||||
},
|
||||
},
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
name: "supports_string_value",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "2",
|
||||
},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "negative_value_is_clamped_to_zero",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": -1,
|
||||
},
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "oversized_value_is_clamped_to_max",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": 99,
|
||||
},
|
||||
},
|
||||
expected: maxPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "invalid_value_falls_back_to_default",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "oops",
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -68,9 +68,9 @@ type AccountRepository interface {
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周)
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
|
||||
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
|
||||
// ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0
|
||||
ResetQuotaUsed(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
|
||||
@@ -406,8 +406,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
||||
mergeAccountExtra(account, updates)
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
|
||||
102
backend/internal/service/account_test_service_openai_test.go
Normal file
102
backend/internal/service/account_test_service_openai_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIAccountTestRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updatedExtra map[string]any
|
||||
rateLimitedID int64
|
||||
rateLimitedAt *time.Time
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
r.rateLimitedAt = &resetAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
||||
|
||||
`))
|
||||
resp.Header.Set("x-codex-primary-used-percent", "88")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "42")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 89,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"])
|
||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 88,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.Error(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, int64(88), repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
|
||||
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
}
|
||||
@@ -359,6 +359,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return usage, nil
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
|
||||
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
@@ -367,7 +368,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
|
||||
if (usage.FiveHour == nil || usage.SevenDay == nil) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||
mergeAccountExtra(account, updates)
|
||||
if usage.UpdatedAt == nil {
|
||||
@@ -409,6 +410,40 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if usage == nil {
|
||||
return true
|
||||
}
|
||||
if usage.FiveHour == nil || usage.SevenDay == nil {
|
||||
return true
|
||||
}
|
||||
if account.IsRateLimited() {
|
||||
return true
|
||||
}
|
||||
return isOpenAICodexSnapshotStale(account, now)
|
||||
}
|
||||
|
||||
func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool {
|
||||
if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() {
|
||||
return false
|
||||
}
|
||||
if account.Extra == nil {
|
||||
return true
|
||||
}
|
||||
raw, ok := account.Extra["codex_usage_updated_at"]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
ts, err := parseTime(fmt.Sprint(raw))
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return now.Sub(ts) >= openAIProbeCacheTTL
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return true
|
||||
@@ -478,20 +513,34 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
updates, err := extractOpenAICodexProbeUpdates(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||
if resp == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
68
backend/internal/service/account_usage_service_test.go
Normal file
68
backend/internal/service/account_usage_service_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rateLimitedUntil := time.Now().Add(5 * time.Minute)
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
FiveHour: &UsageProgress{Utilization: 0},
|
||||
SevenDay: &UsageProgress{Utilization: 0},
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) {
|
||||
t.Fatal("expected rate-limited account to force codex snapshot refresh")
|
||||
}
|
||||
|
||||
if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) {
|
||||
t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh")
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) {
|
||||
t.Fatal("expected missing 5h snapshot to require refresh")
|
||||
}
|
||||
|
||||
staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339)
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_enabled": true,
|
||||
"codex_usage_updated_at": staleAt,
|
||||
},
|
||||
}, usage, now) {
|
||||
t.Fatal("expected stale ws snapshot to trigger refresh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||
if err != nil {
|
||||
t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
t.Fatal("expected codex probe updates from 429 headers")
|
||||
}
|
||||
if got := updates["codex_5h_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
|
||||
}
|
||||
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
@@ -432,6 +432,7 @@ type adminServiceImpl struct {
|
||||
entClient *dbent.Client // 用于开启数据库事务
|
||||
settingService *SettingService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
userSubRepo UserSubscriptionRepository
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
@@ -459,6 +460,7 @@ func NewAdminService(
|
||||
entClient *dbent.Client,
|
||||
settingService *SettingService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
@@ -476,6 +478,7 @@ func NewAdminService(
|
||||
entClient: entClient,
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
userSubRepo: userSubRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1277,9 +1280,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
if group.Status != StatusActive {
|
||||
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
|
||||
}
|
||||
// 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
|
||||
// 订阅类型分组:用户须持有该分组的有效订阅才可绑定
|
||||
if group.IsSubscriptionType() {
|
||||
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
|
||||
if s.userSubRepo == nil {
|
||||
return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured")
|
||||
}
|
||||
if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil {
|
||||
if errors.Is(err, ErrSubscriptionNotFound) {
|
||||
return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
gid := *groupID
|
||||
@@ -1287,7 +1298,7 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
apiKey.Group = group
|
||||
|
||||
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
|
||||
if group.IsExclusive {
|
||||
if group.IsExclusive && !group.IsSubscriptionType() {
|
||||
opCtx := ctx
|
||||
var tx *dbent.Tx
|
||||
if s.entClient == nil {
|
||||
@@ -1349,6 +1360,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
now := time.Now()
|
||||
for i := range accounts {
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
|
||||
}
|
||||
return accounts, result.Total, nil
|
||||
}
|
||||
|
||||
@@ -1484,9 +1499,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
if len(input.Extra) > 0 {
|
||||
// 保留 quota_used,防止编辑账号时意外重置配额用量
|
||||
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
|
||||
input.Extra["quota_used"] = oldQuotaUsed
|
||||
// 保留配额用量字段,防止编辑账号时意外重置
|
||||
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
|
||||
if v, ok := account.Extra[key]; ok {
|
||||
input.Extra[key] = v
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
@@ -1717,16 +1734,10 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if err := s.accountRepo.ClearError(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Status = StatusActive
|
||||
account.ErrorMessage = ""
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
|
||||
@@ -32,28 +32,44 @@ func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context,
|
||||
return s.addGroupErr
|
||||
}
|
||||
|
||||
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
|
||||
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
|
||||
type apiKeyRepoStubForGroupUpdate struct {
|
||||
@@ -194,6 +210,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
type userSubRepoStubForGroupUpdate struct {
|
||||
userSubRepoNoop
|
||||
getActiveSub *UserSubscription
|
||||
getActiveErr error
|
||||
called bool
|
||||
calledUserID int64
|
||||
calledGroupID int64
|
||||
}
|
||||
|
||||
func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
|
||||
s.called = true
|
||||
s.calledUserID = userID
|
||||
s.calledGroupID = groupID
|
||||
if s.getActiveErr != nil {
|
||||
return nil, s.getActiveErr
|
||||
}
|
||||
if s.getActiveSub == nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *s.getActiveSub
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -386,14 +425,49 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
|
||||
|
||||
// 无有效订阅时应拒绝绑定
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
|
||||
require.True(t, userSubRepo.called)
|
||||
require.Equal(t, int64(42), userSubRepo.calledUserID)
|
||||
require.Equal(t, int64(10), userSubRepo.calledGroupID)
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
|
||||
|
||||
// 订阅类型分组应被阻止绑定
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
|
||||
require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
userSubRepo := &userSubRepoStubForGroupUpdate{
|
||||
getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
|
||||
}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.True(t, userSubRepo.called)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
|
||||
@@ -1384,7 +1384,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
|
||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
logBody, maxBytes := s.getLogConfig()
|
||||
@@ -1517,6 +1517,80 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// Budget 整流:检测 budget_tokens 约束错误并自动修正重试
|
||||
if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) {
|
||||
errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "budget_constraint_error",
|
||||
Message: errMsg,
|
||||
Detail: s.getUpstreamErrorDetail(respBody),
|
||||
})
|
||||
|
||||
// 修正 claudeReq 的 thinking 参数(adaptive 模式不修正)
|
||||
if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" {
|
||||
retryClaudeReq := claudeReq
|
||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||||
// 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针
|
||||
retryClaudeReq.Thinking = &antigravity.ThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: BudgetRectifyBudgetTokens,
|
||||
}
|
||||
if retryClaudeReq.MaxTokens < BudgetRectifyMinMaxTokens {
|
||||
retryClaudeReq.MaxTokens = BudgetRectifyMaxTokens
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts)
|
||||
if txErr == nil {
|
||||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: prefix,
|
||||
account: account,
|
||||
proxyURL: proxyURL,
|
||||
accessToken: accessToken,
|
||||
action: action,
|
||||
body: retryGeminiBody,
|
||||
c: c,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
accountRepo: s.accountRepo,
|
||||
handleError: s.handleUpstreamError,
|
||||
requestedModel: originalModel,
|
||||
isStickySession: isStickySession,
|
||||
groupID: 0,
|
||||
sessionHash: "",
|
||||
})
|
||||
if retryErr == nil {
|
||||
retryResp := retryResult.resp
|
||||
if retryResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
resp = retryResp
|
||||
respBody = nil
|
||||
} else {
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
respBody = retryBody
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
Header: retryResp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: budget rectifier retry failed: %v", account.ID, retryErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
|
||||
|
||||
@@ -78,11 +78,11 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key APIKey
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
name string
|
||||
key APIKey
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
}{
|
||||
{
|
||||
name: "all windows active",
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@@ -21,24 +22,25 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||
ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration")
|
||||
)
|
||||
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
@@ -58,6 +60,7 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
entClient *dbent.Client
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
@@ -76,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
entClient *dbent.Client,
|
||||
userRepo UserRepository,
|
||||
redeemRepo RedeemCodeRepository,
|
||||
refreshTokenCache RefreshTokenCache,
|
||||
@@ -88,6 +92,7 @@ func NewAuthService(
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
entClient: entClient,
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
@@ -523,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
@@ -552,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 检查是否需要邀请码
|
||||
var invitationRedeemCode *RedeemCode
|
||||
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
if invitationCode == "" {
|
||||
return nil, nil, ErrOAuthInvitationRequired
|
||||
}
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
@@ -579,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
if s.entClient != nil && invitationRedeemCode != nil {
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
if err := s.userRepo.Create(txCtx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@@ -618,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
|
||||
const pendingOAuthTokenTTL = 10 * time.Minute
|
||||
|
||||
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
|
||||
const pendingOAuthPurpose = "pending_oauth_registration"
|
||||
|
||||
type pendingOAuthClaims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Purpose string `json:"purpose"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
|
||||
// while waiting for the user to supply an invitation code.
|
||||
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: email,
|
||||
Username: username,
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
|
||||
// Returns ErrInvalidToken when the token is invalid or expired.
|
||||
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
|
||||
if len(tokenStr) > maxTokenLength {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
|
||||
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return []byte(s.cfg.JWT.Secret), nil
|
||||
})
|
||||
if parseErr != nil {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
claims, ok := token.Claims.(*pendingOAuthClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
if claims.Purpose != pendingOAuthPurpose {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
return claims.Email, claims.Username, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
|
||||
146
backend/internal/service/auth_service_pending_oauth_test.go
Normal file
146
backend/internal/service/auth_service_pending_oauth_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAuthServiceForPendingOAuthTest() *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret-pending-oauth",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
|
||||
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
email, username, err := svc.VerifyPendingOAuthToken(token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "user@example.com", email)
|
||||
require.Equal(t, "alice", username)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
// 签发一个普通 access token(JWTClaims,无 Purpose 字段)
|
||||
accessToken, err := svc.GenerateToken(&User{
|
||||
ID: 1,
|
||||
Email: "user@example.com",
|
||||
Role: RoleUser,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: "some_other_purpose",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: "", // 旧 token 无此字段,反序列化后为零值
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
past := time.Now().Add(-1 * time.Hour)
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(past),
|
||||
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
|
||||
other := NewAuthService(nil, nil, nil, nil, &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "other-secret"},
|
||||
}, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
_, _, err = svc.VerifyPendingOAuthToken(token)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
giant := make([]byte, maxTokenLength+1)
|
||||
for i := range giant {
|
||||
giant[i] = 'a'
|
||||
}
|
||||
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
@@ -130,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
}
|
||||
|
||||
return NewAuthService(
|
||||
nil, // entClient
|
||||
repo,
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
|
||||
@@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
|
||||
turnstileService := NewTurnstileService(settingService, verifier)
|
||||
|
||||
return NewAuthService(
|
||||
nil, // entClient
|
||||
&userRepoStub{},
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
|
||||
@@ -43,16 +43,19 @@ type BillingCache interface {
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
InputPricePerTokenPriority float64 // priority service tier 下每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -61,6 +64,28 @@ const (
|
||||
openAIGPT54LongContextOutputMultiplier = 1.5
|
||||
)
|
||||
|
||||
func normalizeBillingServiceTier(serviceTier string) string {
|
||||
return strings.ToLower(strings.TrimSpace(serviceTier))
|
||||
}
|
||||
|
||||
func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool {
|
||||
if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" {
|
||||
return false
|
||||
}
|
||||
return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0
|
||||
}
|
||||
|
||||
func serviceTierCostMultiplier(serviceTier string) float64 {
|
||||
switch normalizeBillingServiceTier(serviceTier) {
|
||||
case "priority":
|
||||
return 2.0
|
||||
case "flex":
|
||||
return 0.5
|
||||
default:
|
||||
return 1.0
|
||||
}
|
||||
}
|
||||
|
||||
// UsageTokens 使用的token数量
|
||||
type UsageTokens struct {
|
||||
InputTokens int
|
||||
@@ -173,30 +198,60 @@ func (s *BillingService) initFallbackPricing() {
|
||||
|
||||
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
|
||||
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
OutputPricePerTokenPriority: 20e-6, // $20 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
CacheReadPricePerTokenPriority: 0.25e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.4(业务指定价格)
|
||||
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
InputPricePerTokenPriority: 5e-6, // $5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
OutputPricePerTokenPriority: 30e-6, // $30 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
}
|
||||
// OpenAI GPT-5.2(本地兜底)
|
||||
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
|
||||
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
InputPricePerTokenPriority: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
OutputPricePerTokenPriority: 24e-6, // $24 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
CacheReadPricePerTokenPriority: 0.3e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
|
||||
}
|
||||
@@ -241,6 +296,10 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
switch normalized {
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.2":
|
||||
return s.fallbackPrices["gpt-5.2"]
|
||||
case "gpt-5.2-codex":
|
||||
return s.fallbackPrices["gpt-5.2-codex"]
|
||||
case "gpt-5.3-codex":
|
||||
return s.fallbackPrices["gpt-5.3-codex"]
|
||||
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
|
||||
@@ -269,16 +328,19 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
@@ -295,6 +357,10 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -303,6 +369,21 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
cacheReadPricePerToken := pricing.CacheReadPricePerToken
|
||||
tierMultiplier := 1.0
|
||||
if usePriorityServiceTierPricing(serviceTier, pricing) {
|
||||
if pricing.InputPricePerTokenPriority > 0 {
|
||||
inputPricePerToken = pricing.InputPricePerTokenPriority
|
||||
}
|
||||
if pricing.OutputPricePerTokenPriority > 0 {
|
||||
outputPricePerToken = pricing.OutputPricePerTokenPriority
|
||||
}
|
||||
if pricing.CacheReadPricePerTokenPriority > 0 {
|
||||
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
|
||||
}
|
||||
} else {
|
||||
tierMultiplier = serviceTierCostMultiplier(serviceTier)
|
||||
}
|
||||
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
@@ -329,7 +410,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
|
||||
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
// 计算总费用
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
|
||||
@@ -522,3 +522,189 @@ func TestCalculateCost_LargeTokenCount(t *testing.T) {
|
||||
require.False(t, math.IsNaN(cost.TotalCost))
|
||||
require.False(t, math.IsInf(cost.TotalCost, 0))
|
||||
}
|
||||
|
||||
func TestServiceTierCostMultiplier(t *testing.T) {
|
||||
require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12)
|
||||
require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12)
|
||||
require.InDelta(t, 0.5, serviceTierCostMultiplier("flex"), 1e-12)
|
||||
require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12)
|
||||
require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8}
|
||||
|
||||
baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) {
|
||||
pricingSvc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-5.4": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
InputCostPerTokenPriority: 5e-6,
|
||||
OutputCostPerToken: 15e-6,
|
||||
OutputCostPerTokenPriority: 30e-6,
|
||||
CacheCreationInputTokenCost: 2.5e-6,
|
||||
CacheReadInputTokenCost: 0.25e-6,
|
||||
CacheReadInputTokenCostPriority: 0.5e-6,
|
||||
LongContextInputTokenThreshold: 272000,
|
||||
LongContextInputCostMultiplier: 2.0,
|
||||
LongContextOutputCostMultiplier: 1.5,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewBillingService(&config.Config{}, pricingSvc)
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.Equal(t, 272000, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
gpt52, err := svc.GetModelPricing("gpt-5.2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gpt52)
|
||||
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
|
||||
|
||||
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gpt52Codex)
|
||||
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"custom-no-priority": {
|
||||
InputCostPerToken: 1e-6,
|
||||
OutputCostPerToken: 2e-6,
|
||||
CacheCreationInputTokenCost: 0.5e-6,
|
||||
CacheReadInputTokenCost: 0.25e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
gpt52, err := svc.GetModelPricing("gpt-5.2")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12)
|
||||
|
||||
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"dynamic-tier-model": {
|
||||
InputCostPerToken: 1e-6,
|
||||
InputCostPerTokenPriority: 2e-6,
|
||||
OutputCostPerToken: 3e-6,
|
||||
OutputCostPerTokenPriority: 6e-6,
|
||||
CacheCreationInputTokenCost: 4e-6,
|
||||
CacheCreationInputTokenCostAbove1hr: 5e-6,
|
||||
CacheReadInputTokenCost: 7e-7,
|
||||
CacheReadInputTokenCostPriority: 8e-7,
|
||||
LongContextInputTokenThreshold: 999,
|
||||
LongContextInputCostMultiplier: 1.5,
|
||||
LongContextOutputCostMultiplier: 1.25,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pricing, err := svc.GetModelPricing("dynamic-tier-model")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
|
||||
require.True(t, pricing.SupportsCacheBreakdown)
|
||||
require.InDelta(t, 7e-7, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.Equal(t, 999, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
@@ -43,6 +43,9 @@ type ConcurrencyCache interface {
|
||||
|
||||
// 清理过期槽位(后台任务)
|
||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||
|
||||
// 启动时清理旧进程遗留槽位与等待计数
|
||||
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -59,13 +62,22 @@ func initRequestIDPrefix() string {
|
||||
return "r" + strconv.FormatUint(fallback, 36)
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking.
|
||||
// Format: {process_random_prefix}-{base36_counter}
|
||||
func RequestIDPrefix() string {
|
||||
return requestIDPrefix
|
||||
}
|
||||
|
||||
func generateRequestID() string {
|
||||
seq := requestIDCounter.Add(1)
|
||||
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
|
||||
}
|
||||
|
||||
const (
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
|
||||
@@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error {
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
type trackingConcurrencyCache struct {
|
||||
stubConcurrencyCacheForTest
|
||||
cleanupPrefix string
|
||||
}
|
||||
|
||||
func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error {
|
||||
c.cleanupPrefix = prefix
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
func TestCleanupStaleProcessSlots_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
|
||||
}
|
||||
|
||||
func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) {
|
||||
cache := &trackingConcurrencyCache{}
|
||||
svc := NewConcurrencyService(cache)
|
||||
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
|
||||
require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix)
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_Success(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
@@ -175,6 +175,20 @@ const (
|
||||
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||
|
||||
// =========================
|
||||
// Request Rectifier (请求整流器)
|
||||
// =========================
|
||||
|
||||
// SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget).
|
||||
SettingKeyRectifierSettings = "rectifier_settings"
|
||||
|
||||
// =========================
|
||||
// Beta Policy Settings
|
||||
// =========================
|
||||
|
||||
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
|
||||
SettingKeyBetaPolicySettings = "beta_policy_settings"
|
||||
|
||||
// =========================
|
||||
// Sora S3 存储配置
|
||||
// =========================
|
||||
|
||||
@@ -177,6 +177,36 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
body: []byte(`overloaded`),
|
||||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||
},
|
||||
{
|
||||
name: "pool_mode_custom_error_codes_hit_returns_matched",
|
||||
account: &Account{
|
||||
ID: 7,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(401), float64(403)},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyMatched,
|
||||
},
|
||||
{
|
||||
name: "pool_mode_without_custom_error_codes_returns_skipped",
|
||||
account: &Account{
|
||||
ID: 8,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicySkipped,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -190,6 +220,48 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) {
|
||||
t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 30,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.False(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
})
|
||||
|
||||
t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 31,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(401)},
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) {
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "DroppedBetas removes both context-1m and fast-mode",
|
||||
name: "DroppedBetas is empty (filtering moved to configurable beta policy)",
|
||||
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
|
||||
tokens: claude.DroppedBetas,
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
want: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -114,25 +114,23 @@ func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
|
||||
func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
|
||||
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
|
||||
incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20"
|
||||
// DroppedBetas is now empty — filtering moved to configurable beta policy.
|
||||
// Without a policy filter set, nothing gets dropped from the static set.
|
||||
drop := droppedBetaSet()
|
||||
|
||||
got := mergeAnthropicBetaDropping(required, incoming, drop)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
|
||||
require.NotContains(t, got, "context-1m-2025-08-07")
|
||||
require.NotContains(t, got, "fast-mode-2026-02-01")
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta", got)
|
||||
require.Contains(t, got, "context-1m-2025-08-07")
|
||||
require.Contains(t, got, "fast-mode-2026-02-01")
|
||||
}
|
||||
|
||||
func TestDroppedBetaSet(t *testing.T) {
|
||||
// Base set contains DroppedBetas
|
||||
// Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy)
|
||||
base := droppedBetaSet()
|
||||
require.Contains(t, base, claude.BetaContext1M)
|
||||
require.Contains(t, base, claude.BetaFastMode)
|
||||
require.Len(t, base, len(claude.DroppedBetas))
|
||||
|
||||
// With extra tokens
|
||||
extended := droppedBetaSet(claude.BetaClaudeCode)
|
||||
require.Contains(t, extended, claude.BetaContext1M)
|
||||
require.Contains(t, extended, claude.BetaFastMode)
|
||||
require.Contains(t, extended, claude.BetaClaudeCode)
|
||||
require.Len(t, extended, len(claude.DroppedBetas)+1)
|
||||
}
|
||||
@@ -148,6 +146,32 @@ func TestBuildBetaTokenSet(t *testing.T) {
|
||||
require.Empty(t, empty)
|
||||
}
|
||||
|
||||
func TestContainsBetaToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
token string
|
||||
want bool
|
||||
}{
|
||||
{"present in middle", "oauth-2025-04-20,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true},
|
||||
{"present at start", "fast-mode-2026-02-01,oauth-2025-04-20", "fast-mode-2026-02-01", true},
|
||||
{"present at end", "oauth-2025-04-20,fast-mode-2026-02-01", "fast-mode-2026-02-01", true},
|
||||
{"only token", "fast-mode-2026-02-01", "fast-mode-2026-02-01", true},
|
||||
{"not present", "oauth-2025-04-20,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", false},
|
||||
{"with spaces", "oauth-2025-04-20, fast-mode-2026-02-01 , interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true},
|
||||
{"empty header", "", "fast-mode-2026-02-01", false},
|
||||
{"empty token", "fast-mode-2026-02-01", "", false},
|
||||
{"partial match", "fast-mode-2026-02-01-extra", "fast-mode-2026-02-01", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := containsBetaToken(tt.header, tt.token)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
|
||||
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
|
||||
got := stripBetaTokensWithSet(header, map[string]struct{}{})
|
||||
|
||||
@@ -1986,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||
result := make(map[int64]*UserLoadInfo, len(users))
|
||||
for _, user := range users {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
@@ -258,6 +259,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
if !hasEmptyContent && !containsThinkingBlocks {
|
||||
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
|
||||
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
|
||||
out = removeThinkingDependentContextStrategies(out)
|
||||
return out
|
||||
}
|
||||
return body
|
||||
@@ -395,6 +397,10 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
} else {
|
||||
return body
|
||||
}
|
||||
// Removing "thinking" makes any context_management strategy that requires it invalid
|
||||
// (e.g. clear_thinking_20251015). Strip those entries so the retry request does not
|
||||
// receive a 400 "strategy requires thinking to be enabled or adaptive".
|
||||
out = removeThinkingDependentContextStrategies(out)
|
||||
}
|
||||
if modified {
|
||||
msgsBytes, err := json.Marshal(messages)
|
||||
@@ -409,6 +415,49 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// removeThinkingDependentContextStrategies 从 context_management.edits 中移除
|
||||
// 需要 thinking 启用的策略(如 clear_thinking_20251015)。
|
||||
// 当顶层 "thinking" 字段被禁用时必须调用,否则上游会返回
|
||||
// "strategy requires thinking to be enabled or adaptive"。
|
||||
func removeThinkingDependentContextStrategies(body []byte) []byte {
|
||||
jsonStr := *(*string)(unsafe.Pointer(&body))
|
||||
editsRes := gjson.Get(jsonStr, "context_management.edits")
|
||||
if !editsRes.Exists() || !editsRes.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
var filtered []json.RawMessage
|
||||
hasRemoved := false
|
||||
editsRes.ForEach(func(_, v gjson.Result) bool {
|
||||
if v.Get("type").String() == "clear_thinking_20251015" {
|
||||
hasRemoved = true
|
||||
return true
|
||||
}
|
||||
filtered = append(filtered, json.RawMessage(v.Raw))
|
||||
return true
|
||||
})
|
||||
|
||||
if !hasRemoved {
|
||||
return body
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
if b, err := sjson.DeleteBytes(body, "context_management.edits"); err == nil {
|
||||
return b
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
filteredBytes, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
if b, err := sjson.SetRawBytes(body, "context_management.edits", filteredBytes); err == nil {
|
||||
return b
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
|
||||
// signature/thought_signature validation issues involving tool blocks.
|
||||
//
|
||||
@@ -444,6 +493,28 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
|
||||
if _, exists := req["thinking"]; exists {
|
||||
delete(req, "thinking")
|
||||
modified = true
|
||||
// Remove context_management strategies that require thinking to be enabled
|
||||
// (e.g. clear_thinking_20251015), otherwise upstream returns 400.
|
||||
if cm, ok := req["context_management"].(map[string]any); ok {
|
||||
if edits, ok := cm["edits"].([]any); ok {
|
||||
filtered := make([]any, 0, len(edits))
|
||||
for _, edit := range edits {
|
||||
if editMap, ok := edit.(map[string]any); ok {
|
||||
if editMap["type"] == "clear_thinking_20251015" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, edit)
|
||||
}
|
||||
if len(filtered) != len(edits) {
|
||||
if len(filtered) == 0 {
|
||||
delete(cm, "edits")
|
||||
} else {
|
||||
cm["edits"] = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
@@ -675,3 +746,90 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Thinking Budget Rectifier
|
||||
// =========================
|
||||
|
||||
const (
|
||||
// BudgetRectifyBudgetTokens is the budget_tokens value to set when rectifying.
|
||||
BudgetRectifyBudgetTokens = 32000
|
||||
// BudgetRectifyMaxTokens is the max_tokens value to set when rectifying.
|
||||
BudgetRectifyMaxTokens = 64000
|
||||
// BudgetRectifyMinMaxTokens is the minimum max_tokens that must exceed budget_tokens.
|
||||
BudgetRectifyMinMaxTokens = 32001
|
||||
)
|
||||
|
||||
// isThinkingBudgetConstraintError detects whether an upstream error message indicates
|
||||
// a budget_tokens constraint violation (e.g. "budget_tokens >= 1024").
|
||||
// Matches three conditions (all must be true):
|
||||
// 1. Contains "budget_tokens" or "budget tokens"
|
||||
// 2. Contains "thinking"
|
||||
// 3. Contains ">= 1024" or "greater than or equal to 1024" or ("1024" + "input should be")
|
||||
func isThinkingBudgetConstraintError(errMsg string) bool {
|
||||
m := strings.ToLower(errMsg)
|
||||
|
||||
// Condition 1: budget_tokens or budget tokens
|
||||
hasBudget := strings.Contains(m, "budget_tokens") || strings.Contains(m, "budget tokens")
|
||||
if !hasBudget {
|
||||
return false
|
||||
}
|
||||
|
||||
// Condition 2: thinking
|
||||
if !strings.Contains(m, "thinking") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Condition 3: constraint indicator
|
||||
if strings.Contains(m, ">= 1024") || strings.Contains(m, "greater than or equal to 1024") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(m, "1024") && strings.Contains(m, "input should be") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// RectifyThinkingBudget modifies the request body to fix budget_tokens constraint errors.
|
||||
// It sets thinking.budget_tokens = 32000, thinking.type = "enabled" (unless adaptive),
|
||||
// and ensures max_tokens >= 32001.
|
||||
// Returns (modified body, true) if changes were applied, or (original body, false) if not.
|
||||
func RectifyThinkingBudget(body []byte) ([]byte, bool) {
|
||||
// If thinking type is "adaptive", skip rectification entirely
|
||||
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
||||
if thinkingType == "adaptive" {
|
||||
return body, false
|
||||
}
|
||||
|
||||
modified := body
|
||||
changed := false
|
||||
|
||||
// Set thinking.type = "enabled"
|
||||
if thinkingType != "enabled" {
|
||||
if result, err := sjson.SetBytes(modified, "thinking.type", "enabled"); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Set thinking.budget_tokens = 32000
|
||||
currentBudget := gjson.GetBytes(modified, "thinking.budget_tokens").Int()
|
||||
if currentBudget != BudgetRectifyBudgetTokens {
|
||||
if result, err := sjson.SetBytes(modified, "thinking.budget_tokens", BudgetRectifyBudgetTokens); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure max_tokens >= BudgetRectifyMinMaxTokens
|
||||
maxTokens := gjson.GetBytes(modified, "max_tokens").Int()
|
||||
if maxTokens < int64(BudgetRectifyMinMaxTokens) {
|
||||
if result, err := sjson.SetBytes(modified, "max_tokens", BudgetRectifyMaxTokens); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
return modified, changed
|
||||
}
|
||||
|
||||
@@ -439,6 +439,210 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
||||
require.Contains(t, content1["text"], "tool_result")
|
||||
}
|
||||
|
||||
// ============ Group 6b: context_management.edits 清理测试 ============
|
||||
|
||||
// removeThinkingDependentContextStrategies — 边界用例
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_NoContextManagement(t *testing.T) {
|
||||
input := []byte(`{"thinking":{"type":"enabled"},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "无 context_management 字段时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_EmptyEdits(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "edits 为空数组时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_NoClearThinkingEntry(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"other_strategy"}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "edits 中无 clear_thinking_20251015 时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_RemovesSingleEntry(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasEdits := cm["edits"]
|
||||
require.False(t, hasEdits, "所有 edits 均为 clear_thinking_20251015 时应删除 edits 键")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_MixedEntries(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_strategy","param":1}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留其他条目")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "other_strategy", edit0["type"])
|
||||
}
|
||||
|
||||
// FilterThinkingBlocksForRetry — 包含 context_management 的场景
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_FastPath(t *testing.T) {
|
||||
// 快速路径:messages 中无 thinking 块,仅有顶层 thinking 字段
|
||||
// 这条路径曾因提前 return 跳过 removeThinkingDependentContextStrategies 而存在 bug
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Hello"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasEdits := cm["edits"]
|
||||
require.False(t, hasEdits, "fast path 下 clear_thinking_20251015 应被移除,edits 键应被删除")
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_WithThinkingBlocks(t *testing.T) {
|
||||
// 完整路径:messages 中有 thinking 块(非 fast path)
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"keep_this"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"some thought","signature":"sig"},
|
||||
{"type":"text","text":"Answer"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 keep_this")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "keep_this", edit0["type"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_NoContextManagement_Unaffected(t *testing.T) {
|
||||
// 无 context_management 时不应报错,且 thinking 正常被移除
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled"},
|
||||
"messages":[{"role":"user","content":[{"type":"text","text":"Hi"}]}]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
_, hasCM := req["context_management"]
|
||||
require.False(t, hasCM)
|
||||
}
|
||||
|
||||
// FilterSignatureSensitiveBlocksForRetry — 包含 context_management 的场景
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_RemovesClearThinkingStrategy(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"thought","signature":"sig"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
if rawEdits, hasEdits := cm["edits"]; hasEdits {
|
||||
edits, ok := rawEdits.([]any)
|
||||
require.True(t, ok)
|
||||
for _, e := range edits {
|
||||
em, ok := e.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, "clear_thinking_20251015", em["type"], "clear_thinking_20251015 应被移除")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_PreservesNonThinkingStrategies(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled"},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_edit"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"t","signature":"s"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 other_edit")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "other_edit", edit0["type"])
|
||||
}
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_NoThinkingField_ContextManagementUntouched(t *testing.T) {
|
||||
// 没有顶层 thinking 字段时,context_management 不应被修改
|
||||
input := []byte(`{
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"t","signature":"s"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "无顶层 thinking 时 context_management 不应被修改")
|
||||
}
|
||||
|
||||
// ============ Group 7: ParseGatewayRequest 补充单元测试 ============
|
||||
|
||||
// Task 7.1 — 类型校验边界测试
|
||||
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 40 * 1024 * 1024
|
||||
defaultMaxLineSize = 500 * 1024 * 1024
|
||||
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
|
||||
// to match real Claude CLI traffic as closely as possible. When we need a visual
|
||||
// separator between system blocks, we add "\n\n" at concatenation time.
|
||||
@@ -526,6 +526,7 @@ type GatewayService struct {
|
||||
userGroupRateSF singleflight.Group
|
||||
modelsListCache *gocache.Cache
|
||||
modelsListCacheTTL time.Duration
|
||||
settingService *SettingService
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
@@ -553,6 +554,7 @@ func NewGatewayService(
|
||||
sessionLimitCache SessionLimitCache,
|
||||
rpmCache RPMCache,
|
||||
digestStore *DigestSessionStore,
|
||||
settingService *SettingService,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@@ -579,6 +581,7 @@ func NewGatewayService(
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
settingService: settingService,
|
||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
@@ -994,6 +997,11 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
|
||||
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
|
||||
}
|
||||
|
||||
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
|
||||
func GenerateSessionUUID(seed string) string {
|
||||
return generateSessionUUID(seed)
|
||||
}
|
||||
|
||||
func generateSessionUUID(seed string) string {
|
||||
if seed == "" {
|
||||
return uuid.NewString()
|
||||
@@ -3328,10 +3336,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
if account.Platform == PlatformSora {
|
||||
return s.isSoraModelSupportedByAccount(account, requestedModel)
|
||||
}
|
||||
// OpenAI 透传模式:仅替换认证,允许所有模型
|
||||
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
|
||||
return true
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
@@ -3944,6 +3948,20 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
||||
}
|
||||
|
||||
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
||||
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
||||
if account.Platform == PlatformAnthropic && c != nil {
|
||||
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
|
||||
if policy.blockErr != nil {
|
||||
return nil, policy.blockErr
|
||||
}
|
||||
filterSet := policy.filterSet
|
||||
if filterSet == nil {
|
||||
filterSet = map[string]struct{}{}
|
||||
}
|
||||
c.Set(betaPolicyFilterSetKey, filterSet)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
@@ -4069,7 +4087,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if readErr == nil {
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if s.isThinkingBlockSignatureError(respBody) {
|
||||
if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@@ -4186,7 +4204,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
// 不是thinking签名错误,恢复响应体
|
||||
// 不是签名错误(或整流器已关闭),继续检查 budget 约束
|
||||
errMsg := extractUpstreamErrorMessage(respBody)
|
||||
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "budget_constraint_error",
|
||||
Message: errMsg,
|
||||
Detail: func() string {
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
|
||||
rectifiedBody, applied := RectifyThinkingBudget(body)
|
||||
if applied && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if buildErr == nil {
|
||||
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = budgetRetryResp
|
||||
break
|
||||
}
|
||||
if budgetRetryResp != nil && budgetRetryResp.Body != nil {
|
||||
_ = budgetRetryResp.Body.Close()
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr)
|
||||
} else {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
}
|
||||
}
|
||||
@@ -4278,7 +4334,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -4308,7 +4368,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
@@ -4543,7 +4607,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -4573,7 +4641,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
@@ -5075,6 +5147,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
applyClaudeOAuthHeaderDefaults(req, reqStream)
|
||||
}
|
||||
|
||||
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
|
||||
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
|
||||
effectiveDropSet := mergeDropSets(policyFilterSet)
|
||||
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
|
||||
|
||||
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
|
||||
if tokenType == "oauth" {
|
||||
if mimicClaudeCode {
|
||||
@@ -5088,17 +5165,22 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// messages requests typically use only oauth + interleaved-thinking.
|
||||
// Also drop claude-code beta if a downstream client added it.
|
||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet))
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
|
||||
} else {
|
||||
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet))
|
||||
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
|
||||
}
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
} else {
|
||||
// API-key accounts: apply beta policy filter to strip controlled tokens
|
||||
if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
|
||||
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5276,6 +5358,104 @@ func stripBetaTokensWithSet(header string, drop map[string]struct{}) string {
|
||||
return strings.Join(out, ",")
|
||||
}
|
||||
|
||||
// BetaBlockedError indicates a request was blocked by a beta policy rule.
|
||||
type BetaBlockedError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *BetaBlockedError) Error() string { return e.Message }
|
||||
|
||||
// betaPolicyResult holds the evaluated result of beta policy rules for a single request.
|
||||
type betaPolicyResult struct {
|
||||
blockErr *BetaBlockedError // non-nil if a block rule matched
|
||||
filterSet map[string]struct{} // tokens to filter (may be nil)
|
||||
}
|
||||
|
||||
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
|
||||
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
|
||||
if s.settingService == nil {
|
||||
return betaPolicyResult{}
|
||||
}
|
||||
settings, err := s.settingService.GetBetaPolicySettings(ctx)
|
||||
if err != nil || settings == nil {
|
||||
return betaPolicyResult{}
|
||||
}
|
||||
isOAuth := account.IsOAuth()
|
||||
var result betaPolicyResult
|
||||
for _, rule := range settings.Rules {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||
continue
|
||||
}
|
||||
switch rule.Action {
|
||||
case BetaPolicyActionBlock:
|
||||
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
|
||||
msg := rule.ErrorMessage
|
||||
if msg == "" {
|
||||
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
||||
}
|
||||
result.blockErr = &BetaBlockedError{Message: msg}
|
||||
}
|
||||
case BetaPolicyActionFilter:
|
||||
if result.filterSet == nil {
|
||||
result.filterSet = make(map[string]struct{})
|
||||
}
|
||||
result.filterSet[rule.BetaToken] = struct{}{}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens.
|
||||
// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation).
|
||||
func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} {
|
||||
if len(policySet) == 0 && len(extra) == 0 {
|
||||
return defaultDroppedBetasSet
|
||||
}
|
||||
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra))
|
||||
for t := range defaultDroppedBetasSet {
|
||||
m[t] = struct{}{}
|
||||
}
|
||||
for t := range policySet {
|
||||
m[t] = struct{}{}
|
||||
}
|
||||
for _, t := range extra {
|
||||
m[t] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request.
|
||||
const betaPolicyFilterSetKey = "betaPolicyFilterSet"
|
||||
|
||||
// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available.
|
||||
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
|
||||
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
|
||||
// evaluates on demand (one DB call).
|
||||
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
|
||||
if c != nil {
|
||||
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
|
||||
if fs, ok := v.(map[string]struct{}); ok {
|
||||
return fs
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.evaluateBetaPolicy(ctx, "", account).filterSet
|
||||
}
|
||||
|
||||
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
|
||||
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
|
||||
switch scope {
|
||||
case BetaPolicyScopeAll:
|
||||
return true
|
||||
case BetaPolicyScopeOAuth:
|
||||
return isOAuth
|
||||
case BetaPolicyScopeAPIKey:
|
||||
return !isOAuth
|
||||
default:
|
||||
return true // unknown scope → match all (fail-open)
|
||||
}
|
||||
}
|
||||
|
||||
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
|
||||
func droppedBetaSet(extra ...string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
|
||||
@@ -5288,6 +5468,19 @@ func droppedBetaSet(extra ...string) map[string]struct{} {
|
||||
return m
|
||||
}
|
||||
|
||||
// containsBetaToken checks if a comma-separated header value contains the given token.
|
||||
func containsBetaToken(header, token string) bool {
|
||||
if header == "" || token == "" {
|
||||
return false
|
||||
}
|
||||
for _, p := range strings.Split(header, ",") {
|
||||
if strings.TrimSpace(p) == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildBetaTokenSet(tokens []string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(tokens))
|
||||
for _, t := range tokens {
|
||||
@@ -5299,10 +5492,7 @@ func buildBetaTokenSet(tokens []string) map[string]struct{} {
|
||||
return m
|
||||
}
|
||||
|
||||
var (
|
||||
defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
|
||||
droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode)
|
||||
)
|
||||
var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
|
||||
|
||||
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
|
||||
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
|
||||
@@ -6437,7 +6627,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
@@ -6928,7 +7118,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
@@ -7240,6 +7430,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
applyClaudeOAuthHeaderDefaults(req, false)
|
||||
}
|
||||
|
||||
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
|
||||
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
|
||||
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
if mimicClaudeCode {
|
||||
@@ -7247,8 +7440,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
incomingBeta := req.Header.Get("anthropic-beta")
|
||||
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
|
||||
drop := droppedBetaSet()
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
|
||||
} else {
|
||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||
if clientBetaHeader == "" {
|
||||
@@ -7258,14 +7450,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
||||
beta = beta + "," + claude.BetaTokenCounting
|
||||
}
|
||||
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet))
|
||||
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet))
|
||||
}
|
||||
}
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
} else {
|
||||
// API-key accounts: apply beta policy filter to strip controlled tokens
|
||||
if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
|
||||
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,7 +319,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -687,16 +687,20 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, len(candidates), topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
@@ -705,16 +709,23 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
candidate := selectionOrder[0]
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: candidate.account.ID,
|
||||
MaxConcurrency: candidate.account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
for _, candidate := range selectionOrder {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: fresh.ID,
|
||||
MaxConcurrency: fresh.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
|
||||
return nil, len(candidates), topK, loadSkew, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
|
||||
@@ -12,6 +12,78 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAISnapshotCacheStub struct {
|
||||
SchedulerCache
|
||||
snapshotAccounts []*Account
|
||||
accountsByID map[int64]*Account
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
|
||||
if len(s.snapshotAccounts) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
out := make([]*Account, 0, len(s.snapshotAccounts))
|
||||
for _, account := range s.snapshotAccounts {
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
cloned := *account
|
||||
out = append(out, &cloned)
|
||||
}
|
||||
return out, true, nil
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.accountsByID == nil {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.accountsByID[accountID]
|
||||
if account == nil {
|
||||
return nil, nil
|
||||
}
|
||||
cloned := *account
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10101)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
staleSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(31002), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10102)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
stalePrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshPrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(32002), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(9)
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -38,7 +40,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("parse anthropic request: %w", err)
|
||||
}
|
||||
originalModel := anthropicReq.Model
|
||||
isStream := anthropicReq.Stream
|
||||
clientStream := anthropicReq.Stream // client's original stream preference
|
||||
|
||||
// 2. Convert Anthropic → Responses
|
||||
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
|
||||
@@ -46,6 +48,16 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("convert anthropic to responses: %w", err)
|
||||
}
|
||||
|
||||
// Upstream always uses streaming (upstream may not support sync mode).
|
||||
// The client's original preference determines the response format.
|
||||
responsesReq.Stream = true
|
||||
isStream := true
|
||||
|
||||
// 2b. Handle BetaFastMode → service_tier: "priority"
|
||||
if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
|
||||
responsesReq.ServiceTier = "priority"
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
// 分组级降级:账号未映射时使用分组默认映射模型
|
||||
@@ -72,7 +84,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||
}
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
} else if promptCacheKey != "" {
|
||||
reqBody["prompt_cache_key"] = promptCacheKey
|
||||
}
|
||||
// OAuth codex transform forces stream=true upstream, so always use
|
||||
// the streaming response handler regardless of what the client asked.
|
||||
isStream = true
|
||||
@@ -94,6 +111,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// Override session_id with a deterministic UUID derived from the sticky
|
||||
// session key (buildUpstreamRequest may have set it to the raw value).
|
||||
if promptCacheKey != "" {
|
||||
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
|
||||
}
|
||||
|
||||
// 7. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
@@ -118,12 +141,13 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
|
||||
// 8. Handle error response with failover
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
@@ -145,19 +169,37 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
|
||||
}
|
||||
}
|
||||
// Non-failover error: return Anthropic-formatted error to client
|
||||
return s.handleAnthropicErrorResponse(resp, c, account)
|
||||
}
|
||||
|
||||
// 9. Handle normal response
|
||||
// Upstream is always streaming; choose response format based on client preference.
|
||||
var result *OpenAIForwardResult
|
||||
var handleErr error
|
||||
if isStream {
|
||||
if clientStream {
|
||||
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
} else {
|
||||
result, handleErr = s.handleAnthropicNonStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
// Client wants JSON: buffer the streaming response and assemble a JSON reply.
|
||||
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
}
|
||||
|
||||
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||||
if handleErr == nil && result != nil {
|
||||
if responsesReq.ServiceTier != "" {
|
||||
st := responsesReq.ServiceTier
|
||||
result.ServiceTier = &st
|
||||
}
|
||||
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
|
||||
re := responsesReq.Reasoning.Effort
|
||||
result.ReasoningEffort = &re
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
@@ -227,9 +269,13 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
|
||||
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// handleAnthropicNonStreamingResponse reads a Responses API JSON response,
|
||||
// converts it to Anthropic Messages format, and writes it to the client.
|
||||
func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
|
||||
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
|
||||
// the upstream streaming response, finds the terminal event (response.completed
|
||||
// / response.incomplete / response.failed), converts the complete response to
|
||||
// Anthropic Messages JSON format, and writes it to the client.
|
||||
// This is used when the client requested stream=false but the upstream is always
|
||||
// streaming.
|
||||
func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
@@ -238,29 +284,65 @@ func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read upstream response: %w", err)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var responsesResp apicompat.ResponsesResponse
|
||||
if err := json.Unmarshal(respBody, &responsesResp); err != nil {
|
||||
return nil, fmt.Errorf("parse responses response: %w", err)
|
||||
}
|
||||
|
||||
anthropicResp := apicompat.ResponsesToAnthropic(&responsesResp, originalModel)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
if responsesResp.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: responsesResp.Usage.InputTokens,
|
||||
OutputTokens: responsesResp.Usage.OutputTokens,
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
if responsesResp.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = responsesResp.Usage.InputTokensDetails.CachedTokens
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Terminal events carry the complete ResponsesResponse with output + usage.
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
|
||||
return nil, fmt.Errorf("upstream stream ended without terminal event")
|
||||
}
|
||||
|
||||
anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel)
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
@@ -278,6 +360,9 @@ func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
|
||||
|
||||
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
||||
// converts each to Anthropic SSE events, and writes them to the client.
|
||||
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
||||
// pattern to send Anthropic ping events during periods of upstream silence,
|
||||
// preventing proxy/client timeout disconnections.
|
||||
func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
@@ -293,6 +378,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
state := apicompat.NewResponsesEventToAnthropicState()
|
||||
@@ -302,30 +388,41 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
firstChunk := true
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
// resultWithUsage builds the final result snapshot.
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}
|
||||
payload := line[6:]
|
||||
}
|
||||
|
||||
// processDataLine handles a single "data: ..." SSE line from upstream.
|
||||
// Returns (clientDisconnected bool).
|
||||
processDataLine := func(payload string) bool {
|
||||
if firstChunk {
|
||||
firstChunk = false
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
// Parse the Responses SSE event
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
@@ -352,28 +449,36 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
// Client disconnected — return collected usage
|
||||
logger.L().Info("openai messages stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(events) > 0 {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
// finalizeStream sends any remaining Anthropic events and returns the result.
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
// handleScanErr logs scanner errors if meaningful.
|
||||
handleScanErr := func(err error) {
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages stream: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
@@ -381,27 +486,94 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the Anthropic stream is properly terminated
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
}
|
||||
c.Writer.Flush()
|
||||
// ── Determine keepalive interval ──
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
||||
if keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
// ── With keepalive: goroutine + channel + select ──
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// Upstream closed
|
||||
return finalizeStream()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
// Send Anthropic-format ping event
|
||||
if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil {
|
||||
// Client disconnected
|
||||
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeAnthropicError writes an error response in Anthropic Messages API format.
|
||||
|
||||
@@ -334,3 +334,225 @@ func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testin
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, 0, usageRepo.lastLog.InputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_Gpt54LongContextBillsWholeSession(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_gpt54_long_context",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 300000,
|
||||
OutputTokens: 2000,
|
||||
},
|
||||
Model: "gpt-5.4-2026-03-05",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1014},
|
||||
User: &User{ID: 2014},
|
||||
Account: &Account{ID: 3014},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
|
||||
expectedInput := 300000 * 2.5e-6 * 2.0
|
||||
expectedOutput := 2000 * 15e-6 * 1.5
|
||||
require.InDelta(t, expectedInput, usageRepo.lastLog.InputCost, 1e-10)
|
||||
require.InDelta(t, expectedOutput, usageRepo.lastLog.OutputCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
require.InDelta(t, (expectedInput+expectedOutput)*1.1, usageRepo.lastLog.ActualCost, 1e-10)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ServiceTierPriorityUsesFastPricing(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "priority"
|
||||
usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_service_tier_priority",
|
||||
ServiceTier: &serviceTier,
|
||||
Usage: usage,
|
||||
Model: "gpt-5.4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1015},
|
||||
User: &User{ID: 2015},
|
||||
Account: &Account{ID: 3015},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
|
||||
|
||||
baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 100, OutputTokens: 50}, 1.0)
|
||||
require.NoError(t, calcErr)
|
||||
require.InDelta(t, baseCost.TotalCost*2, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ServiceTierFlexHalvesCost(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "flex"
|
||||
usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50, CacheReadInputTokens: 20}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_service_tier_flex",
|
||||
ServiceTier: &serviceTier,
|
||||
Usage: usage,
|
||||
Model: "gpt-5.4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1016},
|
||||
User: &User{ID: 2016},
|
||||
Account: &Account{ID: 3016},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
|
||||
baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 80, OutputTokens: 50, CacheReadTokens: 20}, 1.0)
|
||||
require.NoError(t, calcErr)
|
||||
require.InDelta(t, baseCost.TotalCost*0.5, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIServiceTier(t *testing.T) {
|
||||
t.Run("fast maps to priority", func(t *testing.T) {
|
||||
got := normalizeOpenAIServiceTier(" fast ")
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, "priority", *got)
|
||||
})
|
||||
|
||||
t.Run("default ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("default"))
|
||||
})
|
||||
|
||||
t.Run("invalid ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
|
||||
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
|
||||
require.Nil(t, extractOpenAIServiceTier(nil))
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "priority"
|
||||
reasoning := "high"
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_billing_model_override",
|
||||
BillingModel: "gpt-5.1-codex",
|
||||
Model: "gpt-5.1",
|
||||
ServiceTier: &serviceTier,
|
||||
ReasoningEffort: &reasoning,
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Duration: 2 * time.Second,
|
||||
FirstTokenMs: func() *int { v := 120; return &v }(),
|
||||
},
|
||||
APIKey: &APIKey{ID: 10, GroupID: i64p(11), Group: &Group{ID: 11, RateMultiplier: 1.2}},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
UserAgent: "codex-cli/1.0",
|
||||
IPAddress: "127.0.0.1",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model)
|
||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
|
||||
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
|
||||
require.Equal(t, reasoning, *usageRepo.lastLog.ReasoningEffort)
|
||||
require.NotNil(t, usageRepo.lastLog.UserAgent)
|
||||
require.Equal(t, "codex-cli/1.0", *usageRepo.lastLog.UserAgent)
|
||||
require.NotNil(t, usageRepo.lastLog.IPAddress)
|
||||
require.Equal(t, "127.0.0.1", *usageRepo.lastLog.IPAddress)
|
||||
require.NotNil(t, usageRepo.lastLog.GroupID)
|
||||
require.Equal(t, int64(11), *usageRepo.lastLog.GroupID)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
subscription := &UserSubscription{ID: 99}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_subscription_billing",
|
||||
Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
|
||||
User: &User{ID: 200},
|
||||
Account: &Account{ID: 300},
|
||||
Subscription: subscription,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, BillingTypeSubscription, usageRepo.lastLog.BillingType)
|
||||
require.NotNil(t, usageRepo.lastLog.SubscriptionID)
|
||||
require.Equal(t, subscription.ID, *usageRepo.lastLog.SubscriptionID)
|
||||
require.Equal(t, 1, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.cfg.RunMode = config.RunModeSimple
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_simple_mode",
|
||||
Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1000},
|
||||
User: &User{ID: 2000},
|
||||
Account: &Account{ID: 3000},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
@@ -213,6 +213,9 @@ type OpenAIForwardResult struct {
|
||||
// This is set by the Anthropic Messages conversion path where
|
||||
// the mapped upstream model differs from the client-facing model.
|
||||
BillingModel string
|
||||
// ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex".
|
||||
// Nil means the request did not specify a recognized tier.
|
||||
ServiceTier *string
|
||||
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
|
||||
// Stored for usage records display; nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
@@ -908,6 +911,36 @@ func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg strin
|
||||
return false
|
||||
}
|
||||
|
||||
func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
||||
if upstreamStatusCode != http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
|
||||
match := func(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(lower, "an error occurred while processing your request") {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(lower, "you can retry your request") &&
|
||||
strings.Contains(lower, "help.openai.com") &&
|
||||
strings.Contains(lower, "request id")
|
||||
}
|
||||
|
||||
if match(upstreamMsg) {
|
||||
return true
|
||||
}
|
||||
if len(upstreamBody) == 0 {
|
||||
return false
|
||||
}
|
||||
if match(gjson.GetBytes(upstreamBody, "error.message").String()) {
|
||||
return true
|
||||
}
|
||||
return match(string(upstreamBody))
|
||||
}
|
||||
|
||||
// ExtractSessionID extracts the raw session ID from headers or body without hashing.
|
||||
// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache.
|
||||
func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string {
|
||||
@@ -1026,7 +1059,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
||||
|
||||
// 3. 按优先级 + LRU 选择最佳账号
|
||||
// Select by priority + LRU
|
||||
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
|
||||
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -1099,7 +1132,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
//
|
||||
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||
// Returns nil if no available account.
|
||||
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
@@ -1111,27 +1144,20 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo
|
||||
continue
|
||||
}
|
||||
|
||||
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
|
||||
if !acc.IsSchedulable() || !acc.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择优先级最高且最久未使用的账号
|
||||
// Select highest priority and least recently used
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
selected = fresh
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterAccount(acc, selected) {
|
||||
selected = acc
|
||||
if s.isBetterAccount(fresh, selected) {
|
||||
selected = fresh
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1309,13 +1335,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, false)
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL)
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
@@ -1359,13 +1389,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
shuffleWithinSortGroups(available)
|
||||
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
@@ -1377,11 +1411,15 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
// ============ Layer 3: Fallback wait ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, false)
|
||||
for _, acc := range candidates {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
AccountID: fresh.ID,
|
||||
MaxConcurrency: fresh.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
@@ -1418,11 +1456,44 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
|
||||
fresh := account
|
||||
if s.schedulerSnapshot != nil {
|
||||
current, err := s.getSchedulableAccount(ctx, account.ID)
|
||||
if err != nil || current == nil {
|
||||
return nil
|
||||
}
|
||||
fresh = current
|
||||
}
|
||||
|
||||
if !fresh.IsSchedulable() || !fresh.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !fresh.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
return fresh
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
)
|
||||
if s.schedulerSnapshot != nil {
|
||||
account, err = s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
} else {
|
||||
account, err = s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
if err != nil || account == nil {
|
||||
return account, err
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now())
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
@@ -1477,6 +1548,13 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
||||
if s.shouldFailoverUpstreamError(statusCode) {
|
||||
return true
|
||||
}
|
||||
return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
@@ -1975,13 +2053,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
@@ -2002,7 +2080,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, body)
|
||||
}
|
||||
@@ -2036,11 +2118,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
serviceTier := extractOpenAIServiceTier(reqBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: serviceTier,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: false,
|
||||
@@ -2195,6 +2279,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: reqModel,
|
||||
ServiceTier: extractOpenAIServiceTierFromBody(body),
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: false,
|
||||
@@ -2815,7 +2900,11 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: body,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
// Return appropriate error response
|
||||
@@ -3594,6 +3683,13 @@ type OpenAIRecordUsageInput struct {
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
|
||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
@@ -3628,7 +3724,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if result.BillingModel != "" {
|
||||
billingModel = result.BillingModel
|
||||
}
|
||||
cost, err := s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
}
|
||||
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
@@ -3649,6 +3749,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: billingModel,
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -3871,6 +3972,69 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow
|
||||
return updates
|
||||
}
|
||||
|
||||
func codexUsagePercentExhausted(value *float64) bool {
|
||||
return value != nil && *value >= 100-1e-9
|
||||
}
|
||||
|
||||
func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time {
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
return nil
|
||||
}
|
||||
baseTime := codexSnapshotBaseTime(snapshot, fallbackNow)
|
||||
if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil {
|
||||
resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
|
||||
return &resetAt
|
||||
}
|
||||
if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil {
|
||||
resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
|
||||
return &resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time {
|
||||
if len(extra) == 0 {
|
||||
return nil
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
||||
resetAt := progress.ResetsAt.UTC()
|
||||
return &resetAt
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
||||
resetAt := progress.ResetsAt.UTC()
|
||||
return &resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) {
|
||||
if account == nil || !account.IsOpenAI() {
|
||||
return nil, false
|
||||
}
|
||||
resetAt := codexRateLimitResetAtFromExtra(account.Extra, now)
|
||||
if resetAt == nil {
|
||||
return nil, false
|
||||
}
|
||||
if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) {
|
||||
return account.RateLimitResetAt, false
|
||||
}
|
||||
account.RateLimitResetAt = resetAt
|
||||
return resetAt, true
|
||||
}
|
||||
|
||||
func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time {
|
||||
resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now)
|
||||
if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 {
|
||||
return resetAt
|
||||
}
|
||||
_ = repo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
return resetAt
|
||||
}
|
||||
|
||||
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
|
||||
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
|
||||
if snapshot == nil {
|
||||
@@ -3880,16 +4044,22 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
return
|
||||
}
|
||||
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) == 0 {
|
||||
now := time.Now()
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, now)
|
||||
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now)
|
||||
if len(updates) == 0 && resetAt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update account's Extra field asynchronously
|
||||
go func() {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
if len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}
|
||||
if resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -4047,6 +4217,40 @@ func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *s
|
||||
return &value
|
||||
}
|
||||
|
||||
func extractOpenAIServiceTier(reqBody map[string]any) *string {
|
||||
if reqBody == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := reqBody["service_tier"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return normalizeOpenAIServiceTier(raw)
|
||||
}
|
||||
|
||||
func extractOpenAIServiceTierFromBody(body []byte) *string {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalizeOpenAIServiceTier(gjson.GetBytes(body, "service_tier").String())
|
||||
}
|
||||
|
||||
func normalizeOpenAIServiceTier(raw string) *string {
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
if value == "fast" {
|
||||
value = "priority"
|
||||
}
|
||||
switch value {
|
||||
case "priority", "flex":
|
||||
return &value
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
|
||||
if c != nil {
|
||||
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {
|
||||
|
||||
@@ -211,6 +211,26 @@ func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T)
|
||||
require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查"))
|
||||
}
|
||||
|
||||
func TestIsOpenAITransientProcessingError(t *testing.T) {
|
||||
require.True(t, isOpenAITransientProcessingError(
|
||||
http.StatusBadRequest,
|
||||
"An error occurred while processing your request.",
|
||||
nil,
|
||||
))
|
||||
|
||||
require.True(t, isOpenAITransientProcessingError(
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
[]byte(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message."}}`),
|
||||
))
|
||||
|
||||
require.False(t, isOpenAITransientProcessingError(
|
||||
http.StatusBadRequest,
|
||||
"Missing required parameter: 'instructions'",
|
||||
[]byte(`{"error":{"message":"Missing required parameter: 'instructions'"}}`),
|
||||
))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
@@ -264,3 +284,51 @@ func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing
|
||||
require.True(t, logSink.ContainsField("request_body_size"))
|
||||
require.False(t, logSink.ContainsField("request_body_preview"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_TransientProcessingErrorTriggersFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"x-request-id": []string{"rid-processing-400"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message.","type":"invalid_request_error"}}`)),
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{ForceCodexCLI: false},
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1001,
|
||||
Name: "codex max套餐",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"api_key": "sk-test"},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}]}`)
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
|
||||
require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层换号,而不是直接向客户端写响应")
|
||||
}
|
||||
|
||||
@@ -671,7 +671,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"service_tier":"fast","input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstreamSSE := strings.Join([]string{
|
||||
`data: {"type":"response.output_text.delta","delta":"h"}`,
|
||||
@@ -711,6 +711,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test
|
||||
require.GreaterOrEqual(t, time.Since(start), time.Duration(0))
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
require.GreaterOrEqual(t, *result.FirstTokenMs, 0)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "priority", *result.ServiceTier)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) {
|
||||
@@ -777,7 +779,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd
|
||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
||||
c.Request.Header.Set("X-Test", "keep")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"service_tier":"flex","max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
|
||||
@@ -803,8 +805,11 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "flex", *result.ServiceTier)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, originalBody, upstream.lastBody)
|
||||
require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String())
|
||||
|
||||
@@ -130,6 +130,7 @@ type OpenAITokenInfo struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
OrganizationID string `json:"organization_id,omitempty"`
|
||||
PlanType string `json:"plan_type,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
@@ -202,6 +203,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
tokenInfo.PlanType = userInfo.PlanType
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
@@ -246,6 +248,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
tokenInfo.PlanType = userInfo.PlanType
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
@@ -510,6 +513,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
if tokenInfo.OrganizationID != "" {
|
||||
creds["organization_id"] = tokenInfo.OrganizationID
|
||||
}
|
||||
if tokenInfo.PlanType != "" {
|
||||
creds["plan_type"] = tokenInfo.PlanType
|
||||
}
|
||||
if strings.TrimSpace(tokenInfo.ClientID) != "" {
|
||||
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,13 @@ func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit,
|
||||
openAIStickyLegacyDualWriteTotal.Load()
|
||||
}
|
||||
|
||||
// DeriveSessionHashFromSeed computes the current-format sticky-session hash
|
||||
// from an arbitrary seed string.
|
||||
func DeriveSessionHashFromSeed(seed string) string {
|
||||
currentHash, _ := deriveOpenAISessionHashes(seed)
|
||||
return currentHash
|
||||
}
|
||||
|
||||
func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) {
|
||||
normalized := strings.TrimSpace(sessionID)
|
||||
if normalized == "" {
|
||||
|
||||
@@ -48,6 +48,43 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
account := Account{
|
||||
ID: 12,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
RateLimitResetAt: &rateLimitedUntil,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连")
|
||||
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl")
|
||||
require.NoError(t, getErr)
|
||||
require.Zero(t, boundAccountID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
|
||||
@@ -1853,6 +1853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
wsPath,
|
||||
account.ProxyID != nil && account.Proxy != nil,
|
||||
)
|
||||
var dialErr *openAIWSDialError
|
||||
if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
|
||||
}
|
||||
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
|
||||
}
|
||||
defer lease.Release()
|
||||
@@ -2136,6 +2140,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errMsg := strings.TrimSpace(errMsgRaw)
|
||||
if errMsg == "" {
|
||||
errMsg = "Upstream websocket error"
|
||||
@@ -2302,6 +2307,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
RequestID: responseID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: extractOpenAIServiceTier(reqBody),
|
||||
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: true,
|
||||
@@ -2639,6 +2645,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
wsPath,
|
||||
account.ProxyID != nil && account.Proxy != nil,
|
||||
)
|
||||
var dialErr *openAIWSDialError
|
||||
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
|
||||
}
|
||||
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
|
||||
return nil, NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
@@ -2777,6 +2787,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
|
||||
@@ -2913,6 +2924,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
RequestID: responseID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: extractOpenAIServiceTierFromBody(payload),
|
||||
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: true,
|
||||
@@ -3604,6 +3616,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
|
||||
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errMsg := strings.TrimSpace(errMsgRaw)
|
||||
if errMsg == "" {
|
||||
errMsg = "OpenAI websocket prewarm error"
|
||||
@@ -3798,7 +3811,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
||||
if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() {
|
||||
if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
|
||||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -3867,6 +3880,36 @@ func classifyOpenAIWSAcquireError(err error) string {
|
||||
return "acquire_conn"
|
||||
}
|
||||
|
||||
func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool {
|
||||
code := strings.ToLower(strings.TrimSpace(codeRaw))
|
||||
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
|
||||
msg := strings.ToLower(strings.TrimSpace(msgRaw))
|
||||
|
||||
if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) {
|
||||
if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI {
|
||||
return
|
||||
}
|
||||
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
||||
}
|
||||
|
||||
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
|
||||
code := strings.ToLower(strings.TrimSpace(codeRaw))
|
||||
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
|
||||
@@ -3882,6 +3925,9 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
|
||||
case "previous_response_not_found":
|
||||
return "previous_response_not_found", true
|
||||
}
|
||||
if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||
return "upstream_rate_limited", false
|
||||
}
|
||||
if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
|
||||
return "upgrade_required", true
|
||||
}
|
||||
@@ -3927,9 +3973,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
|
||||
case strings.Contains(errType, "permission"),
|
||||
strings.Contains(code, "forbidden"):
|
||||
return http.StatusForbidden
|
||||
case strings.Contains(errType, "rate_limit"),
|
||||
strings.Contains(code, "rate_limit"),
|
||||
strings.Contains(code, "insufficient_quota"):
|
||||
case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""):
|
||||
return http.StatusTooManyRequests
|
||||
default:
|
||||
return http.StatusBadGateway
|
||||
|
||||
@@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -424,6 +424,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
require.True(t, result.OpenAIWSMode)
|
||||
require.Equal(t, 2, result.Usage.InputTokens)
|
||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "priority", *result.ServiceTier)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("未收到 passthrough turn 结果回调")
|
||||
}
|
||||
@@ -2593,7 +2595,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect
|
||||
require.NoError(t, err)
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`))
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false,"service_tier":"flex"}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
// 立即关闭客户端,模拟客户端在 relay 期间断连。
|
||||
@@ -2611,6 +2613,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect
|
||||
require.Equal(t, "resp_ingress_disconnect", result.RequestID)
|
||||
require.Equal(t, 2, result.Usage.InputTokens)
|
||||
require.Equal(t, 1, result.Usage.OutputTokens)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "flex", *result.ServiceTier)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("未收到断连后的 turn 结果回调")
|
||||
}
|
||||
|
||||
477
backend/internal/service/openai_ws_ratelimit_signal_test.go
Normal file
477
backend/internal/service/openai_ws_ratelimit_signal_test.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIWSRateLimitSignalRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCalls []time.Time
|
||||
updateExtra []map[string]any
|
||||
}
|
||||
|
||||
type openAICodexSnapshotAsyncRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
updateExtraCh chan map[string]any
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
type openAICodexExtraListRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
r.rateLimitCalls = append(r.rateLimitCalls, resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtra = append(r.updateExtra, copied)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
if r.updateExtraCh != nil {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtraCh <- copied
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
_ = platform
|
||||
_ = accountType
|
||||
_ = status
|
||||
_ = search
|
||||
_ = groupID
|
||||
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
resetAt := time.Now().Add(2 * time.Hour).Unix()
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "rate_limit_exceeded",
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached",
|
||||
"resets_at": resetAt,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
|
||||
account := Account{
|
||||
ID: 501,
|
||||
Name: "openai-ws-rate-limit-event",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: upstream,
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, &account, body)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP")
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("x-codex-primary-used-percent", "100")
|
||||
w.Header().Set("x-codex-primary-reset-after-seconds", "7200")
|
||||
w.Header().Set("x-codex-primary-window-minutes", "10080")
|
||||
w.Header().Set("x-codex-secondary-used-percent", "3")
|
||||
w.Header().Set("x-codex-secondary-reset-after-seconds", "1800")
|
||||
w.Header().Set("x-codex-secondary-window-minutes", "300")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"type":"rate_limit_exceeded","message":"rate limited"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
|
||||
account := Account{
|
||||
ID: 502,
|
||||
Name: "openai-ws-rate-limit-handshake",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": server.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: upstream,
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, &account, body)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP")
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库")
|
||||
require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
resetAt := time.Now().Add(90 * time.Minute).Unix()
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}}`),
|
||||
},
|
||||
}
|
||||
captureConn.events[0] = []byte(strings.ReplaceAll(string(captureConn.events[0]), "PLACEHOLDER", strconv.FormatInt(resetAt, 10)))
|
||||
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
account := Account{
|
||||
ID: 503,
|
||||
Name: "openai-ingress-rate-limit",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.CloseNow() }()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, &account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = clientConn.CloseNow() }()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.Error(t, serverErr)
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) {
|
||||
repo := &openAICodexSnapshotAsyncRepo{
|
||||
updateExtraCh: make(chan map[string]any, 1),
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: ptrFloat64WS(100),
|
||||
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||
SecondaryUsedPercent: ptrFloat64WS(12),
|
||||
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||
SecondaryWindowMinutes: ptrIntWS(300),
|
||||
}
|
||||
before := time.Now()
|
||||
svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot)
|
||||
|
||||
select {
|
||||
case updates := <-repo.updateExtraCh:
|
||||
require.Equal(t, 100.0, updates["codex_7d_used_percent"])
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 快照落库超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case resetAt := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 100% 自动切换限流超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) {
|
||||
repo := &openAICodexSnapshotAsyncRepo{
|
||||
updateExtraCh: make(chan map[string]any, 1),
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: ptrFloat64WS(94),
|
||||
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||
SecondaryUsedPercent: ptrFloat64WS(22),
|
||||
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||
SecondaryWindowMinutes: ptrIntWS(300),
|
||||
}
|
||||
svc.updateCodexUsageSnapshot(context.Background(), 602, snapshot)
|
||||
|
||||
select {
|
||||
case <-repo.updateExtraCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 快照落库超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case resetAt := <-repo.rateLimitCh:
|
||||
t.Fatalf("unexpected rate limit reset at: %v", resetAt)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func ptrFloat64WS(v float64) *float64 { return &v }
|
||||
func ptrIntWS(v int) *int { return &v }
|
||||
|
||||
func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) {
|
||||
resetAt := time.Now().Add(6 * 24 * time.Hour)
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
|
||||
fresh, err := svc.getSchedulableAccount(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fresh)
|
||||
require.NotNil(t, fresh.RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待旧快照补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) {
|
||||
resetAt := time.Now().Add(4 * 24 * time.Hour)
|
||||
repo := &openAICodexExtraListRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{
|
||||
ID: 702,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}}},
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Len(t, accounts, 1)
|
||||
require.NotNil(t, accounts[0].RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待列表补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) {
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached"))
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", ""))
|
||||
}
|
||||
@@ -77,6 +77,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
return errors.New("token is empty")
|
||||
}
|
||||
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
||||
requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
|
||||
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
||||
@@ -178,6 +179,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
ServiceTier: requestServiceTier,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
@@ -225,6 +227,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
ServiceTier: requestServiceTier,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
|
||||
@@ -40,13 +40,17 @@ var (
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
InputCostPerTokenPriority float64 `json:"input_cost_per_token_priority"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
OutputCostPerTokenPriority float64 `json:"output_cost_per_token_priority"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
CacheReadInputTokenCostPriority float64 `json:"cache_read_input_token_cost_priority"`
|
||||
LongContextInputTokenThreshold int `json:"long_context_input_token_threshold,omitempty"`
|
||||
LongContextInputCostMultiplier float64 `json:"long_context_input_cost_multiplier,omitempty"`
|
||||
LongContextOutputCostMultiplier float64 `json:"long_context_output_cost_multiplier,omitempty"`
|
||||
SupportsServiceTier bool `json:"supports_service_tier"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
@@ -62,10 +66,14 @@ type PricingRemoteClient interface {
|
||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||
type LiteLLMRawEntry struct {
|
||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||
InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority"`
|
||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||
OutputCostPerTokenPriority *float64 `json:"output_cost_per_token_priority"`
|
||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||
CacheReadInputTokenCostPriority *float64 `json:"cache_read_input_token_cost_priority"`
|
||||
SupportsServiceTier bool `json:"supports_service_tier"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
@@ -324,14 +332,21 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
LiteLLMProvider: entry.LiteLLMProvider,
|
||||
Mode: entry.Mode,
|
||||
SupportsPromptCaching: entry.SupportsPromptCaching,
|
||||
SupportsServiceTier: entry.SupportsServiceTier,
|
||||
}
|
||||
|
||||
if entry.InputCostPerToken != nil {
|
||||
pricing.InputCostPerToken = *entry.InputCostPerToken
|
||||
}
|
||||
if entry.InputCostPerTokenPriority != nil {
|
||||
pricing.InputCostPerTokenPriority = *entry.InputCostPerTokenPriority
|
||||
}
|
||||
if entry.OutputCostPerToken != nil {
|
||||
pricing.OutputCostPerToken = *entry.OutputCostPerToken
|
||||
}
|
||||
if entry.OutputCostPerTokenPriority != nil {
|
||||
pricing.OutputCostPerTokenPriority = *entry.OutputCostPerTokenPriority
|
||||
}
|
||||
if entry.CacheCreationInputTokenCost != nil {
|
||||
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
||||
}
|
||||
@@ -341,6 +356,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
if entry.CacheReadInputTokenCost != nil {
|
||||
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
||||
}
|
||||
if entry.CacheReadInputTokenCostPriority != nil {
|
||||
pricing.CacheReadInputTokenCostPriority = *entry.CacheReadInputTokenCostPriority
|
||||
}
|
||||
if entry.OutputCostPerImage != nil {
|
||||
pricing.OutputCostPerImage = *entry.OutputCostPerImage
|
||||
}
|
||||
|
||||
@@ -1,11 +1,40 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParsePricingData_ParsesPriorityAndServiceTierFields(t *testing.T) {
|
||||
svc := &PricingService{}
|
||||
body := []byte(`{
|
||||
"gpt-5.4": {
|
||||
"input_cost_per_token": 0.0000025,
|
||||
"input_cost_per_token_priority": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"output_cost_per_token_priority": 0.00003,
|
||||
"cache_creation_input_token_cost": 0.0000025,
|
||||
"cache_read_input_token_cost": 0.00000025,
|
||||
"cache_read_input_token_cost_priority": 0.0000005,
|
||||
"supports_service_tier": true,
|
||||
"supports_prompt_caching": true,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
}
|
||||
}`)
|
||||
|
||||
data, err := svc.parsePricingData(body)
|
||||
require.NoError(t, err)
|
||||
pricing := data["gpt-5.4"]
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 3e-5, pricing.OutputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 5e-7, pricing.CacheReadInputTokenCostPriority, 1e-12)
|
||||
require.True(t, pricing.SupportsServiceTier)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_Gpt53CodexSparkUsesGpt51CodexPricing(t *testing.T) {
|
||||
sparkPricing := &LiteLLMModelPricing{InputCostPerToken: 1}
|
||||
gpt53Pricing := &LiteLLMModelPricing{InputCostPerToken: 9}
|
||||
@@ -68,3 +97,64 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T)
|
||||
require.InDelta(t, 2.0, got.LongContextInputCostMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) {
|
||||
raw := map[string]any{
|
||||
"gpt-5.4": map[string]any{
|
||||
"input_cost_per_token": 2.5e-6,
|
||||
"input_cost_per_token_priority": 5e-6,
|
||||
"output_cost_per_token": 15e-6,
|
||||
"output_cost_per_token_priority": 30e-6,
|
||||
"cache_read_input_token_cost": 0.25e-6,
|
||||
"cache_read_input_token_cost_priority": 0.5e-6,
|
||||
"supports_service_tier": true,
|
||||
"supports_prompt_caching": true,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(raw)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &PricingService{}
|
||||
pricingMap, err := svc.parsePricingData(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
pricing := pricingMap["gpt-5.4"]
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 30e-6, pricing.OutputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadInputTokenCost, 1e-12)
|
||||
require.InDelta(t, 0.5e-6, pricing.CacheReadInputTokenCostPriority, 1e-12)
|
||||
require.True(t, pricing.SupportsServiceTier)
|
||||
}
|
||||
|
||||
func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) {
|
||||
svc := &PricingService{}
|
||||
pricingData, err := svc.parsePricingData([]byte(`{
|
||||
"gpt-5.4": {
|
||||
"input_cost_per_token": 0.0000025,
|
||||
"input_cost_per_token_priority": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"output_cost_per_token_priority": 0.00003,
|
||||
"cache_read_input_token_cost": 0.00000025,
|
||||
"cache_read_input_token_cost_priority": 0.0000005,
|
||||
"supports_service_tier": true,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
}
|
||||
}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
pricing := pricingData["gpt-5.4"]
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 0.0000025, pricing.InputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 0.000005, pricing.InputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.000015, pricing.OutputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 0.00003, pricing.OutputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.00000025, pricing.CacheReadInputTokenCost, 1e-12)
|
||||
require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12)
|
||||
require.True(t, pricing.SupportsServiceTier)
|
||||
}
|
||||
|
||||
@@ -28,6 +28,17 @@ type RateLimitService struct {
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
|
||||
type SuccessfulTestRecoveryResult struct {
|
||||
ClearedError bool
|
||||
ClearedRateLimit bool
|
||||
}
|
||||
|
||||
// AccountRecoveryOptions 控制账号恢复时的附加行为。
|
||||
type AccountRecoveryOptions struct {
|
||||
InvalidateToken bool
|
||||
}
|
||||
|
||||
type geminiUsageCacheEntry struct {
|
||||
windowStart time.Time
|
||||
cachedAt time.Time
|
||||
@@ -87,6 +98,9 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
|
||||
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return ErrorPolicySkipped
|
||||
}
|
||||
if account.IsPoolMode() {
|
||||
return ErrorPolicySkipped
|
||||
}
|
||||
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
|
||||
return ErrorPolicyTempUnscheduled
|
||||
}
|
||||
@@ -96,9 +110,16 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||
|
||||
// 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。
|
||||
if account.IsPoolMode() && !customErrorCodesEnabled {
|
||||
slog.Info("pool_mode_error_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
// apikey 类型账号:检查自定义错误码配置
|
||||
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return false
|
||||
@@ -615,6 +636,7 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
|
||||
if account.Platform == PlatformOpenAI {
|
||||
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
@@ -878,6 +900,23 @@ func pickSooner(a, b *time.Time) *time.Time {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RateLimitService) persistOpenAICodexSnapshot(ctx context.Context, account *Account, headers http.Header) {
|
||||
if s == nil || s.accountRepo == nil || account == nil || headers == nil {
|
||||
return
|
||||
}
|
||||
snapshot := ParseCodexRateLimitHeaders(headers)
|
||||
if snapshot == nil {
|
||||
return
|
||||
}
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, updates); err != nil {
|
||||
slog.Warn("openai_codex_snapshot_persist_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
// OpenAI 的 usage_limit_reached 错误格式:
|
||||
//
|
||||
@@ -1022,6 +1061,42 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecoverAccountState 按需恢复账号的可恢复运行时状态。
|
||||
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SuccessfulTestRecoveryResult{}
|
||||
if account.Status == StatusError {
|
||||
if err := s.accountRepo.ClearError(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.ClearedError = true
|
||||
if options.InvalidateToken && s.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := s.tokenCacheInvalidator.InvalidateToken(ctx, account); invalidateErr != nil {
|
||||
slog.Warn("recover_account_state_invalidate_token_failed", "account_id", accountID, "error", invalidateErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasRecoverableRuntimeState(account) {
|
||||
if err := s.ClearRateLimit(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.ClearedRateLimit = true
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RecoverAccountAfterSuccessfulTest 将一次成功测试视为正常请求,
|
||||
// 按需恢复 error / rate-limit / overload / temp-unsched / model-rate-limit 等运行时状态。
|
||||
func (s *RateLimitService) RecoverAccountAfterSuccessfulTest(ctx context.Context, accountID int64) (*SuccessfulTestRecoveryResult, error) {
|
||||
return s.RecoverAccountState(ctx, accountID, AccountRecoveryOptions{})
|
||||
}
|
||||
|
||||
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
||||
if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil {
|
||||
return err
|
||||
@@ -1038,6 +1113,36 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasRecoverableRuntimeState(account *Account) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.RateLimitedAt != nil || account.RateLimitResetAt != nil || account.OverloadUntil != nil || account.TempUnschedulableUntil != nil {
|
||||
return true
|
||||
}
|
||||
if len(account.Extra) == 0 {
|
||||
return false
|
||||
}
|
||||
return hasNonEmptyMapValue(account.Extra, "model_rate_limits") || hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes")
|
||||
}
|
||||
|
||||
func hasNonEmptyMapValue(extra map[string]any, key string) bool {
|
||||
raw, ok := extra[key]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
return len(typed) > 0
|
||||
case map[string]string:
|
||||
return len(typed) > 0
|
||||
case []any:
|
||||
return len(typed) > 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) {
|
||||
now := time.Now().Unix()
|
||||
if s.tempUnschedCache != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -13,16 +14,34 @@ import (
|
||||
|
||||
type rateLimitClearRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
getByIDAccount *Account
|
||||
getByIDErr error
|
||||
getByIDCalls int
|
||||
clearErrorCalls int
|
||||
clearRateLimitCalls int
|
||||
clearAntigravityCalls int
|
||||
clearModelRateLimitCalls int
|
||||
clearTempUnschedCalls int
|
||||
clearErrorErr error
|
||||
clearRateLimitErr error
|
||||
clearAntigravityErr error
|
||||
clearModelRateLimitErr error
|
||||
clearTempUnschedulableErr error
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
r.getByIDCalls++
|
||||
if r.getByIDErr != nil {
|
||||
return nil, r.getByIDErr
|
||||
}
|
||||
return r.getByIDAccount, nil
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) ClearError(ctx context.Context, id int64) error {
|
||||
r.clearErrorCalls++
|
||||
return r.clearErrorErr
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
r.clearRateLimitCalls++
|
||||
return r.clearRateLimitErr
|
||||
@@ -48,6 +67,11 @@ type tempUnschedCacheRecorder struct {
|
||||
deleteErr error
|
||||
}
|
||||
|
||||
type recoverTokenInvalidatorStub struct {
|
||||
accounts []*Account
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *tempUnschedCacheRecorder) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
|
||||
return nil
|
||||
}
|
||||
@@ -61,6 +85,11 @@ func (c *tempUnschedCacheRecorder) DeleteTempUnsched(ctx context.Context, accoun
|
||||
return c.deleteErr
|
||||
}
|
||||
|
||||
func (s *recoverTokenInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
s.accounts = append(s.accounts, account)
|
||||
return s.err
|
||||
}
|
||||
|
||||
func TestRateLimitService_ClearRateLimit_AlsoClearsTempUnschedulable(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
@@ -170,3 +199,108 @@ func TestRateLimitService_ClearRateLimit_WithoutTempUnschedCache(t *testing.T) {
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLimitRelatedState(t *testing.T) {
|
||||
now := time.Now()
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 42,
|
||||
Status: StatusError,
|
||||
RateLimitedAt: &now,
|
||||
TempUnschedulableUntil: &now,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": now.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
"antigravity_quota_scopes": map[string]any{"gemini": true},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.ClearedError)
|
||||
require.True(t, result.ClearedRateLimit)
|
||||
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Equal(t, 1, repo.clearRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearAntigravityCalls)
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
require.Equal(t, []int64{42}, cache.deletedIDs)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 7,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{},
|
||||
},
|
||||
}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 7)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.ClearedError)
|
||||
require.False(t, result.ClearedRateLimit)
|
||||
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 0, repo.clearErrorCalls)
|
||||
require.Equal(t, 0, repo.clearRateLimitCalls)
|
||||
require.Equal(t, 0, repo.clearAntigravityCalls)
|
||||
require.Equal(t, 0, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 0, repo.clearTempUnschedCalls)
|
||||
require.Empty(t, cache.deletedIDs)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearErrorFailed(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 9,
|
||||
Status: StatusError,
|
||||
},
|
||||
clearErrorErr: errors.New("clear error failed"),
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 9)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Equal(t, 0, repo.clearRateLimitCalls)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountState_InvalidatesOAuthTokenOnErrorRecovery(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 21,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusError,
|
||||
},
|
||||
}
|
||||
invalidator := &recoverTokenInvalidatorStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc.SetTokenCacheInvalidator(invalidator)
|
||||
|
||||
result, err := svc.RecoverAccountState(context.Background(), 21, AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.ClearedError)
|
||||
require.False(t, result.ClearedRateLimit)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
require.Equal(t, int64(21), invalidator.accounts[0].ID)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -141,6 +144,51 @@ func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type openAI429SnapshotRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
rateLimitedID int64
|
||||
updatedExtra map[string]any
|
||||
}
|
||||
|
||||
func (r *openAI429SnapshotRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAI429SnapshotRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) {
|
||||
repo := &openAI429SnapshotRepo{}
|
||||
svc := NewRateLimitService(repo, nil, nil, nil, nil)
|
||||
account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
svc.handle429(context.Background(), account, headers, nil)
|
||||
|
||||
if repo.rateLimitedID != account.ID {
|
||||
t.Fatalf("rateLimitedID = %d, want %d", repo.rateLimitedID, account.ID)
|
||||
}
|
||||
if len(repo.updatedExtra) == 0 {
|
||||
t.Fatal("expected codex snapshot to be persisted on 429")
|
||||
}
|
||||
if got := repo.updatedExtra["codex_5h_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
|
||||
}
|
||||
if got := repo.updatedExtra["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits(t *testing.T) {
|
||||
// Test the Normalize() method directly
|
||||
pUsed := 100.0
|
||||
|
||||
@@ -13,6 +13,7 @@ type ScheduledTestPlan struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover bool `json:"auto_recover"`
|
||||
LastRunAt *time.Time `json:"last_run_at"`
|
||||
NextRunAt *time.Time `json:"next_run_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
@@ -17,6 +17,7 @@ type ScheduledTestRunnerService struct {
|
||||
planRepo ScheduledTestPlanRepository
|
||||
scheduledSvc *ScheduledTestService
|
||||
accountTestSvc *AccountTestService
|
||||
rateLimitSvc *RateLimitService
|
||||
cfg *config.Config
|
||||
|
||||
cron *cron.Cron
|
||||
@@ -29,12 +30,14 @@ func NewScheduledTestRunnerService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
scheduledSvc *ScheduledTestService,
|
||||
accountTestSvc *AccountTestService,
|
||||
rateLimitSvc *RateLimitService,
|
||||
cfg *config.Config,
|
||||
) *ScheduledTestRunnerService {
|
||||
return &ScheduledTestRunnerService{
|
||||
planRepo: planRepo,
|
||||
scheduledSvc: scheduledSvc,
|
||||
accountTestSvc: accountTestSvc,
|
||||
rateLimitSvc: rateLimitSvc,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -127,6 +130,11 @@ func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *Sched
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err)
|
||||
}
|
||||
|
||||
// Auto-recover account if test succeeded and auto_recover is enabled.
|
||||
if result.Status == "success" && plan.AutoRecover {
|
||||
s.tryRecoverAccount(ctx, plan.AccountID, plan.ID)
|
||||
}
|
||||
|
||||
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err)
|
||||
@@ -137,3 +145,26 @@ func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *Sched
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// tryRecoverAccount attempts to recover an account from recoverable runtime state.
|
||||
func (s *ScheduledTestRunnerService) tryRecoverAccount(ctx context.Context, accountID int64, planID int64) {
|
||||
if s.rateLimitSvc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
recovery, err := s.rateLimitSvc.RecoverAccountAfterSuccessfulTest(ctx, accountID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover failed: %v", planID, err)
|
||||
return
|
||||
}
|
||||
if recovery == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if recovery.ClearedError {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d recovered from error status", planID, accountID)
|
||||
}
|
||||
if recovery.ClearedRateLimit {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d cleared rate-limit/runtime state", planID, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1194,6 +1194,113 @@ func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string {
|
||||
return ver
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
func (s *SettingService) GetRectifierSettings(ctx context.Context) (*RectifierSettings, error) {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRectifierSettings)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return DefaultRectifierSettings(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("get rectifier settings: %w", err)
|
||||
}
|
||||
if value == "" {
|
||||
return DefaultRectifierSettings(), nil
|
||||
}
|
||||
|
||||
var settings RectifierSettings
|
||||
if err := json.Unmarshal([]byte(value), &settings); err != nil {
|
||||
return DefaultRectifierSettings(), nil
|
||||
}
|
||||
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
// SetRectifierSettings 设置请求整流器配置
|
||||
func (s *SettingService) SetRectifierSettings(ctx context.Context, settings *RectifierSettings) error {
|
||||
if settings == nil {
|
||||
return fmt.Errorf("settings cannot be nil")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal rectifier settings: %w", err)
|
||||
}
|
||||
|
||||
return s.settingRepo.Set(ctx, SettingKeyRectifierSettings, string(data))
|
||||
}
|
||||
|
||||
// IsSignatureRectifierEnabled 判断签名整流是否启用(总开关 && 签名子开关)
|
||||
func (s *SettingService) IsSignatureRectifierEnabled(ctx context.Context) bool {
|
||||
settings, err := s.GetRectifierSettings(ctx)
|
||||
if err != nil {
|
||||
return true // fail-open: 查询失败时默认启用
|
||||
}
|
||||
return settings.Enabled && settings.ThinkingSignatureEnabled
|
||||
}
|
||||
|
||||
// IsBudgetRectifierEnabled 判断 Budget 整流是否启用(总开关 && Budget 子开关)
|
||||
func (s *SettingService) IsBudgetRectifierEnabled(ctx context.Context) bool {
|
||||
settings, err := s.GetRectifierSettings(ctx)
|
||||
if err != nil {
|
||||
return true // fail-open: 查询失败时默认启用
|
||||
}
|
||||
return settings.Enabled && settings.ThinkingBudgetEnabled
|
||||
}
|
||||
|
||||
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||
func (s *SettingService) GetBetaPolicySettings(ctx context.Context) (*BetaPolicySettings, error) {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyBetaPolicySettings)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return DefaultBetaPolicySettings(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("get beta policy settings: %w", err)
|
||||
}
|
||||
if value == "" {
|
||||
return DefaultBetaPolicySettings(), nil
|
||||
}
|
||||
|
||||
var settings BetaPolicySettings
|
||||
if err := json.Unmarshal([]byte(value), &settings); err != nil {
|
||||
return DefaultBetaPolicySettings(), nil
|
||||
}
|
||||
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
// SetBetaPolicySettings 设置 Beta 策略配置
|
||||
func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *BetaPolicySettings) error {
|
||||
if settings == nil {
|
||||
return fmt.Errorf("settings cannot be nil")
|
||||
}
|
||||
|
||||
validActions := map[string]bool{
|
||||
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
|
||||
}
|
||||
validScopes := map[string]bool{
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true,
|
||||
}
|
||||
|
||||
for i, rule := range settings.Rules {
|
||||
if rule.BetaToken == "" {
|
||||
return fmt.Errorf("rule[%d]: beta_token cannot be empty", i)
|
||||
}
|
||||
if !validActions[rule.Action] {
|
||||
return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
|
||||
}
|
||||
if !validScopes[rule.Scope] {
|
||||
return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal beta policy settings: %w", err)
|
||||
}
|
||||
|
||||
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
|
||||
}
|
||||
|
||||
// SetStreamTimeoutSettings 设置流超时处理配置
|
||||
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
|
||||
if settings == nil {
|
||||
|
||||
@@ -175,3 +175,61 @@ func DefaultStreamTimeoutSettings() *StreamTimeoutSettings {
|
||||
ThresholdWindowMinutes: 10,
|
||||
}
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"` // 总开关
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` // Thinking 签名整流
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` // Thinking Budget 整流
|
||||
}
|
||||
|
||||
// DefaultRectifierSettings 返回默认的整流器配置(全部启用)
|
||||
func DefaultRectifierSettings() *RectifierSettings {
|
||||
return &RectifierSettings{
|
||||
Enabled: true,
|
||||
ThinkingSignatureEnabled: true,
|
||||
ThinkingBudgetEnabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Beta Policy 策略常量
|
||||
const (
|
||||
BetaPolicyActionPass = "pass" // 透传,不做任何处理
|
||||
BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token
|
||||
BetaPolicyActionBlock = "block" // 拦截,直接返回错误
|
||||
|
||||
BetaPolicyScopeAll = "all" // 所有账号类型
|
||||
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
|
||||
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
|
||||
)
|
||||
|
||||
// BetaPolicyRule 单条 Beta 策略规则
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"` // beta token 值
|
||||
Action string `json:"action"` // "pass" | "filter" | "block"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey"
|
||||
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
|
||||
}
|
||||
|
||||
// BetaPolicySettings Beta 策略配置
|
||||
type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// DefaultBetaPolicySettings 返回默认的 Beta 策略配置
|
||||
func DefaultBetaPolicySettings() *BetaPolicySettings {
|
||||
return &BetaPolicySettings{
|
||||
Rules: []BetaPolicyRule{
|
||||
{
|
||||
BetaToken: "fast-mode-2026-02-01",
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
},
|
||||
{
|
||||
BetaToken: "context-1m-2025-08-07",
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +98,8 @@ type UsageLog struct {
|
||||
AccountID int64
|
||||
RequestID string
|
||||
Model string
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API),
|
||||
// e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
|
||||
@@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
|
||||
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
|
||||
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
|
||||
svc := NewConcurrencyService(cache)
|
||||
if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil {
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
|
||||
}
|
||||
if cfg != nil {
|
||||
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
@@ -287,9 +290,10 @@ func ProvideScheduledTestRunnerService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
scheduledSvc *ScheduledTestService,
|
||||
accountTestSvc *AccountTestService,
|
||||
rateLimitSvc *RateLimitService,
|
||||
cfg *config.Config,
|
||||
) *ScheduledTestRunnerService {
|
||||
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, cfg)
|
||||
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, rateLimitSvc, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -23,10 +24,19 @@ import (
|
||||
|
||||
// Config paths
|
||||
const (
|
||||
ConfigFileName = "config.yaml"
|
||||
InstallLockFile = ".installed"
|
||||
ConfigFileName = "config.yaml"
|
||||
InstallLockFile = ".installed"
|
||||
defaultUserConcurrency = 5
|
||||
simpleModeAdminConcurrency = 30
|
||||
)
|
||||
|
||||
func setupDefaultAdminConcurrency() int {
|
||||
if strings.EqualFold(strings.TrimSpace(os.Getenv("RUN_MODE")), config.RunModeSimple) {
|
||||
return simpleModeAdminConcurrency
|
||||
}
|
||||
return defaultUserConcurrency
|
||||
}
|
||||
|
||||
// GetDataDir returns the data directory for storing config and lock files.
|
||||
// Priority: DATA_DIR env > /app/data (if exists and writable) > current directory
|
||||
func GetDataDir() string {
|
||||
@@ -390,7 +400,7 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) {
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 0,
|
||||
Concurrency: 5,
|
||||
Concurrency: setupDefaultAdminConcurrency(),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
@@ -462,7 +472,7 @@ func writeConfigFile(cfg *SetupConfig) error {
|
||||
APIKeyPrefix string `yaml:"api_key_prefix"`
|
||||
RateMultiplier float64 `yaml:"rate_multiplier"`
|
||||
}{
|
||||
UserConcurrency: 5,
|
||||
UserConcurrency: defaultUserConcurrency,
|
||||
UserBalance: 0,
|
||||
APIKeyPrefix: "sk-",
|
||||
RateMultiplier: 1.0,
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package setup
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecideAdminBootstrap(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -49,3 +53,37 @@ func TestDecideAdminBootstrap(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupDefaultAdminConcurrency(t *testing.T) {
|
||||
t.Run("simple mode admin uses higher concurrency", func(t *testing.T) {
|
||||
t.Setenv("RUN_MODE", "simple")
|
||||
if got := setupDefaultAdminConcurrency(); got != simpleModeAdminConcurrency {
|
||||
t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", got, simpleModeAdminConcurrency)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("standard mode keeps existing default", func(t *testing.T) {
|
||||
t.Setenv("RUN_MODE", "standard")
|
||||
if got := setupDefaultAdminConcurrency(); got != defaultUserConcurrency {
|
||||
t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", got, defaultUserConcurrency)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteConfigFileKeepsDefaultUserConcurrency(t *testing.T) {
|
||||
t.Setenv("RUN_MODE", "simple")
|
||||
t.Setenv("DATA_DIR", t.TempDir())
|
||||
|
||||
if err := writeConfigFile(&SetupConfig{}); err != nil {
|
||||
t.Fatalf("writeConfigFile() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(GetConfigFilePath())
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(data), "user_concurrency: 5") {
|
||||
t.Fatalf("config missing default user concurrency, got:\n%s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco
|
||||
func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
func (c StubConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// StubGatewayCache — service.GatewayCache 的空实现
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user