mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-07 00:40:22 +08:00
Compare commits
55 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25178cdbe1 | ||
|
|
a461538d58 | ||
|
|
ebe6f418f3 | ||
|
|
391e79f8ee | ||
|
|
c7fcb7a84b | ||
|
|
87f4ed591e | ||
|
|
440d2e28ed | ||
|
|
6cb8980404 | ||
|
|
fe752bbd35 | ||
|
|
c74d451fa2 | ||
|
|
12d743fb35 | ||
|
|
6acb9f7910 | ||
|
|
eb6f5c6927 | ||
|
|
7ccb4c8ea3 | ||
|
|
4ce986d47d | ||
|
|
91ef085d7d | ||
|
|
97aaa24733 | ||
|
|
faf6441633 | ||
|
|
00c151b463 | ||
|
|
a2ae9f1f27 | ||
|
|
4cd6d86426 | ||
|
|
fa72f1947a | ||
|
|
9ee7d3935d | ||
|
|
1071fe0ac7 | ||
|
|
0be003377f | ||
|
|
ca3f497b56 | ||
|
|
034b84b707 | ||
|
|
1624523c4e | ||
|
|
313afe14ce | ||
|
|
01180b316f | ||
|
|
ee7d061001 | ||
|
|
60c5949a74 | ||
|
|
2ebbd4c94d | ||
|
|
785115c62b | ||
|
|
e643fc382c | ||
|
|
34aad82ac3 | ||
|
|
0c29468f90 | ||
|
|
9301dae63e | ||
|
|
2475d4a205 | ||
|
|
be75fc3474 | ||
|
|
785e049af3 | ||
|
|
be4e49e6d7 | ||
|
|
1307d604e7 | ||
|
|
45d57018eb | ||
|
|
03bf348530 | ||
|
|
cab60ef735 | ||
|
|
a3791104f9 | ||
|
|
2b3e40bb2a | ||
|
|
0c1dcad429 | ||
|
|
101ef0cf62 | ||
|
|
0debe0a80c | ||
|
|
d22e62ac8a | ||
|
|
1ee17383f8 | ||
|
|
b59c79c458 | ||
|
|
c28f691f32 |
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -42,6 +42,6 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.11
|
||||
version: v2.9
|
||||
args: --timeout=30m
|
||||
working-directory: backend
|
||||
@@ -162,7 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
@@ -229,7 +229,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
|
||||
@@ -1402,7 +1402,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
|
||||
@@ -660,6 +660,42 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService != nil {
|
||||
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverState handles unified recovery of recoverable account runtime state.
|
||||
// POST /api/v1/admin/accounts/:id/recover-state
|
||||
func (h *AccountHandler) RecoverState(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
}); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
|
||||
@@ -25,6 +25,7 @@ type createScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
@@ -32,6 +33,7 @@ type updateScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
@@ -68,6 +70,9 @@ func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
plan.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
@@ -109,6 +114,9 @@ func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
existing.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
|
||||
@@ -1348,6 +1348,63 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
// GET /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: settings.Enabled,
|
||||
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRectifierSettingsRequest 更新整流器配置请求
|
||||
type UpdateRectifierSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// UpdateRectifierSettings 更新请求整流器配置
|
||||
// PUT /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
var req UpdateRectifierSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.RectifierSettings{
|
||||
Enabled: req.Enabled,
|
||||
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
out := &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -98,6 +98,19 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) {
|
||||
t := k.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
out.Reset5hAt = &t
|
||||
}
|
||||
if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) {
|
||||
t := k.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
out.Reset1dAt = &t
|
||||
}
|
||||
if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) {
|
||||
t := k.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
out.Reset7dAt = &t
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
@@ -125,9 +138,9 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
@@ -255,11 +268,19 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
if out.QuotaLimit != nil {
|
||||
used := a.GetQuotaUsed()
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -475,6 +496,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
|
||||
@@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -161,6 +161,13 @@ type StreamTimeoutSettings struct {
|
||||
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置 DTO
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -57,6 +57,9 @@ type APIKey struct {
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
Reset5hAt *time.Time `json:"reset_5h_at,omitempty"`
|
||||
Reset1dAt *time.Time `json:"reset_1d_at,omitempty"`
|
||||
Reset7dAt *time.Time `json:"reset_7d_at,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -193,8 +196,12 @@ type Account struct {
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"`
|
||||
QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"`
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -315,6 +322,8 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
maxSameAccountRetries = 3
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
|
||||
@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
for i := 1; i <= maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, i, fs.SameAccountRetryCount[100])
|
||||
}
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
// 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
}
|
||||
// 再次触发时才会执行 TempUnschedule + 切换
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
}
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
@@ -972,33 +972,45 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.EffectiveUsage5h()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit5h-used),
|
||||
"window_start": rateLimitData.Window5hStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) {
|
||||
entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.EffectiveUsage1d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit1d-used),
|
||||
"window_start": rateLimitData.Window1dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) {
|
||||
entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.EffectiveUsage7d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit7d-used),
|
||||
"window_start": rateLimitData.Window7dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) {
|
||||
entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if len(rateLimits) > 0 {
|
||||
resp["rate_limits"] = rateLimits
|
||||
|
||||
@@ -155,6 +155,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // sessionLimitCache
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -212,6 +213,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -259,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
@@ -288,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -538,9 +560,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
|
||||
// 而非 OpenAI 的 session_id/conversation_id headers。
|
||||
// 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。
|
||||
if sessionHash == "" || promptCacheKey == "" {
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = service.GenerateSessionUUID(seed)
|
||||
}
|
||||
if sessionHash == "" {
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -602,6 +640,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -641,6 +680,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -1456,6 +1514,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
|
||||
if sessionHash != "" || account == nil || !account.IsPoolMode() {
|
||||
return sessionHash
|
||||
}
|
||||
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
|
||||
return "openai-pool-retry-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
|
||||
@@ -2207,7 +2207,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -445,6 +445,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
|
||||
@@ -49,8 +49,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
||||
var defaultUserAgentVersion = "1.20.4"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -631,7 +631,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.Contains(t, resp.Include, "reasoning.encrypted_content")
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
@@ -648,7 +649,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
@@ -663,8 +665,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
// Default effort applies (high → xhigh) even when thinking is disabled.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
@@ -676,7 +679,93 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
// Default effort applies (high → xhigh) when no thinking/output_config is set.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// output_config.effort override tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
|
||||
// Default is xhigh, but output_config.effort="low" overrides. low→low after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "low"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "low", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
|
||||
// No thinking field, but output_config.effort="medium" → creates reasoning.
|
||||
// medium→high after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "medium"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
|
||||
// output_config.effort="high" → mapped to "xhigh".
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "high"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
|
||||
// No output_config → default xhigh regardless of thinking.type.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
|
||||
// output_config present but effort empty (e.g. only format set) → default xhigh.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -733,3 +822,188 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", fn["name"])
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Image content block conversion tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_UserImageBlock(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"text","text":"What is in this image?"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "What is in this image?", parts[0].Text)
|
||||
assert.Equal(t, "input_image", parts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ImageOnlyUserMessage(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"/9j/4AAQ"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/jpeg;base64,/9j/4AAQ", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolResultWithImage(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Read the screenshot"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_1","name":"Read","input":{"file_path":"/tmp/screen.png"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"toolu_1","content":[
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output + user(image) = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
// function_call_output should have text-only output (no image).
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "fc_toolu_1", items[2].CallID)
|
||||
assert.Equal(t, "(empty)", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
assert.Equal(t, "user", items[3].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[3].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolResultMixed(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Describe the file"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_2","name":"Read","input":{"file_path":"/tmp/photo.png"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"toolu_2","content":[
|
||||
{"type":"text","text":"File metadata: 800x600 PNG"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output + user(image) = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
// function_call_output should have text-only output.
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "File metadata: 800x600 PNG", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
assert.Equal(t, "user", items[3].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[3].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/png;base64,AAAA", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_TextOnlyToolResultBackwardCompat(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Check weather"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"call_1","content":[
|
||||
{"type":"text","text":"Sunny, 72°F"}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output = 3
|
||||
require.Len(t, items, 3)
|
||||
|
||||
// Text-only tool_result should produce a plain string.
|
||||
assert.Equal(t, "Sunny, 72°F", items[2].Output)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"image","source":{"type":"base64","media_type":"","data":"iVBOR"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
// Should default to image/png when media_type is empty.
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
@@ -45,18 +45,16 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
out.Tools = convertAnthropicToolsToResponses(req.Tools)
|
||||
}
|
||||
|
||||
// Convert thinking → reasoning.
|
||||
// generate_summary="auto" causes the upstream to emit reasoning_summary_text
|
||||
// streaming events; the include array only needs reasoning.encrypted_content
|
||||
// (already set above) for content continuity.
|
||||
if req.Thinking != nil {
|
||||
switch req.Thinking.Type {
|
||||
case "enabled":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "high", Summary: "auto"}
|
||||
case "adaptive":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "medium", Summary: "auto"}
|
||||
}
|
||||
// "disabled" or unknown → omit reasoning
|
||||
// Determine reasoning effort: only output_config.effort controls the
|
||||
// level; thinking.type is ignored. Default is xhigh when unset.
|
||||
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
|
||||
effort := "high" // default → maps to xhigh
|
||||
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
|
||||
effort = req.OutputConfig.Effort
|
||||
}
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: mapAnthropicEffortToResponses(effort),
|
||||
Summary: "auto",
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
@@ -169,7 +167,7 @@ func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, err
|
||||
|
||||
// anthropicUserToResponses handles an Anthropic user message. Content can be a
|
||||
// plain string or an array of blocks. tool_result blocks are extracted into
|
||||
// function_call_output items.
|
||||
// function_call_output items. Image blocks are converted to input_image parts.
|
||||
func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
@@ -184,28 +182,46 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error)
|
||||
}
|
||||
|
||||
var out []ResponsesInputItem
|
||||
var toolResultImageParts []ResponsesContentPart
|
||||
|
||||
// Extract tool_result blocks → function_call_output items.
|
||||
// Images inside tool_results are extracted separately because the
|
||||
// Responses API function_call_output.output only accepts strings.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_result" {
|
||||
continue
|
||||
}
|
||||
text := extractAnthropicToolResultText(b)
|
||||
if text == "" {
|
||||
// OpenAI Responses API requires "output" field; use placeholder for empty results.
|
||||
text = "(empty)"
|
||||
}
|
||||
outputText, imageParts := convertToolResultOutput(b)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Type: "function_call_output",
|
||||
CallID: toResponsesCallID(b.ToolUseID),
|
||||
Output: text,
|
||||
Output: outputText,
|
||||
})
|
||||
toolResultImageParts = append(toolResultImageParts, imageParts...)
|
||||
}
|
||||
|
||||
// Remaining text blocks → user message.
|
||||
text := extractAnthropicTextFromBlocks(blocks)
|
||||
if text != "" {
|
||||
content, _ := json.Marshal(text)
|
||||
// Remaining text + image blocks → user message with content parts.
|
||||
// Also include images extracted from tool_results so the model can see them.
|
||||
var parts []ResponsesContentPart
|
||||
for _, b := range blocks {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
if b.Text != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(b.Source); uri != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, toolResultImageParts...)
|
||||
|
||||
if len(parts) > 0 {
|
||||
content, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ResponsesInputItem{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
@@ -290,26 +306,64 @@ func fromResponsesCallID(id string) string {
|
||||
return id
|
||||
}
|
||||
|
||||
// extractAnthropicToolResultText gets the text content from a tool_result block.
|
||||
func extractAnthropicToolResultText(b AnthropicContentBlock) string {
|
||||
if len(b.Content) == 0 {
|
||||
// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string.
|
||||
// Returns "" if the source is nil or has no data.
|
||||
func anthropicImageToDataURI(src *AnthropicImageSource) string {
|
||||
if src == nil || src.Data == "" {
|
||||
return ""
|
||||
}
|
||||
mediaType := src.MediaType
|
||||
if mediaType == "" {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
return "data:" + mediaType + ";base64," + src.Data
|
||||
}
|
||||
|
||||
// convertToolResultOutput extracts text and image content from a tool_result
|
||||
// block. Returns the text as a string for the function_call_output Output
|
||||
// field, plus any image parts that must be sent in a separate user message
|
||||
// (the Responses API output field only accepts strings).
|
||||
func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) {
|
||||
if len(b.Content) == 0 {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Try plain string content.
|
||||
var s string
|
||||
if err := json.Unmarshal(b.Content, &s); err == nil {
|
||||
return s
|
||||
if s == "" {
|
||||
s = "(empty)"
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Array of content blocks — may contain text and/or images.
|
||||
var inner []AnthropicContentBlock
|
||||
if err := json.Unmarshal(b.Content, &inner); err == nil {
|
||||
var parts []string
|
||||
for _, ib := range inner {
|
||||
if ib.Type == "text" && ib.Text != "" {
|
||||
parts = append(parts, ib.Text)
|
||||
if err := json.Unmarshal(b.Content, &inner); err != nil {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Separate text (for function_call_output) from images (for user message).
|
||||
var textParts []string
|
||||
var imageParts []ResponsesContentPart
|
||||
for _, ib := range inner {
|
||||
switch ib.Type {
|
||||
case "text":
|
||||
if ib.Text != "" {
|
||||
textParts = append(textParts, ib.Text)
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(ib.Source); uri != "" {
|
||||
imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
return ""
|
||||
|
||||
text := strings.Join(textParts, "\n\n")
|
||||
if text == "" {
|
||||
text = "(empty)"
|
||||
}
|
||||
return text, imageParts
|
||||
}
|
||||
|
||||
// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/
|
||||
@@ -324,6 +378,23 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
|
||||
// OpenAI Responses API effort levels.
|
||||
//
|
||||
// low → low
|
||||
// medium → high
|
||||
// high → xhigh
|
||||
func mapAnthropicEffortToResponses(effort string) string {
|
||||
switch effort {
|
||||
case "medium":
|
||||
return "high"
|
||||
case "high":
|
||||
return "xhigh"
|
||||
default:
|
||||
return effort // "low" and any unknown values pass through unchanged
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
|
||||
// Responses API tools. Server-side tools like web_search are mapped to their
|
||||
// OpenAI equivalents; regular tools become function tools.
|
||||
|
||||
@@ -12,17 +12,23 @@ import "encoding/json"
|
||||
|
||||
// AnthropicRequest is the request body for POST /v1/messages.
|
||||
type AnthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicOutputConfig controls output generation parameters.
|
||||
type AnthropicOutputConfig struct {
|
||||
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
|
||||
}
|
||||
|
||||
// AnthropicThinking configures extended thinking in the Anthropic API.
|
||||
@@ -47,6 +53,9 @@ type AnthropicContentBlock struct {
|
||||
// type=thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// type=image
|
||||
Source *AnthropicImageSource `json:"source,omitempty"`
|
||||
|
||||
// type=tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -58,9 +67,16 @@ type AnthropicContentBlock struct {
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicImageSource describes the source data for an image content block.
|
||||
type AnthropicImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// AnthropicTool describes a tool available to the model.
|
||||
type AnthropicTool struct {
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
@@ -146,6 +162,7 @@ type ResponsesRequest struct {
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesReasoning configures reasoning effort in the Responses API.
|
||||
@@ -176,8 +193,9 @@ type ResponsesInputItem struct {
|
||||
|
||||
// ResponsesContentPart is a typed content part in a Responses message.
|
||||
type ResponsesContentPart struct {
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"` // data URI for input_image
|
||||
}
|
||||
|
||||
// ResponsesTool describes a tool in the Responses API.
|
||||
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
|
||||
var DroppedBetas = []string{BetaFastMode}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
@@ -659,13 +659,10 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -925,6 +922,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1040,6 +1038,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1676,13 +1675,47 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
-- 总额度:始终递增
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
-- 日额度:仅在 quota_daily_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
@@ -1704,7 +1737,7 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return err
|
||||
}
|
||||
|
||||
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
|
||||
// 任一维度配额刚超限时触发调度快照刷新
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
@@ -1713,14 +1746,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
'0'::jsonb
|
||||
), updated_at = NOW()
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -558,6 +558,26 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-clear-err",
|
||||
Status: service.StatusError,
|
||||
ErrorMessage: "temporary error",
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusActive, got.Status)
|
||||
s.Require().Empty(got.ErrorMessage)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
|
||||
@@ -165,8 +165,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -476,8 +476,8 @@ func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
@@ -491,9 +491,9 @@ func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64)
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
|
||||
@@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
|
||||
@@ -20,16 +20,16 @@ func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanReposit
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
@@ -37,7 +37,7 @@ func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*s
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
@@ -50,7 +50,7 @@ func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accou
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
@@ -65,10 +65,10 @@ func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ type scannable interface {
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30"
|
||||
simpleModeLegacyAdminConcurrency = 5
|
||||
simpleModeTargetAdminConcurrency = 30
|
||||
)
|
||||
|
||||
func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
|
||||
upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
if upgraded {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := client.User.Update().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency),
|
||||
).
|
||||
SetConcurrency(simpleModeTargetAdminConcurrency).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("upgrade simple mode admin concurrency: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if err := client.Setting.Create().
|
||||
SetKey(simpleModeAdminConcurrencyUpgradeKey).
|
||||
SetValue(now.Format(time.RFC3339)).
|
||||
SetUpdatedAt(now).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx); err != nil {
|
||||
return fmt.Errorf("persist admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
@@ -135,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
@@ -144,7 +145,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -158,6 +159,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
mediaType := nullString(log.MediaType)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -198,6 +200,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -2505,6 +2508,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
mediaType sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
@@ -2544,6 +2548,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&mediaType,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
@@ -2614,6 +2619,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if mediaType.Valid {
|
||||
log.MediaType = &mediaType.String
|
||||
}
|
||||
if serviceTier.Valid {
|
||||
log.ServiceTier = &serviceTier.String
|
||||
}
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
|
||||
@@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(), // image_size
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.Equal(t, int64(99), log.ID)
|
||||
require.Nil(t, log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-service-tier",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
WithArgs(
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
log.RateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
|
||||
inserted, err := repo.Create(context.Background(), log)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
@@ -280,11 +345,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
@@ -316,13 +384,53 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "flex", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeStream, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.False(t, log.OpenAIWSMode)
|
||||
})
|
||||
|
||||
t.Run("service_tier_is_scanned", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(3),
|
||||
int64(12),
|
||||
int64(22),
|
||||
int64(32),
|
||||
sql.NullString{Valid: true, String: "req-3"},
|
||||
"gpt-5.4",
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -210,8 +210,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"allow_messages_dispatch": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"allow_messages_dispatch": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
|
||||
@@ -244,6 +244,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
@@ -392,6 +393,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 流超时处理配置
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
// 请求整流器配置
|
||||
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
||||
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
||||
// Sora S3 存储配置
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
|
||||
@@ -647,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
defaultPoolModeRetryCount = 3
|
||||
maxPoolModeRetryCount = 10
|
||||
)
|
||||
|
||||
// GetPoolModeRetryCount 返回池模式同账号重试次数。
|
||||
// 未配置或配置非法时回退为默认值 3;小于 0 按 0 处理;过大则截断到 10。
|
||||
func (a *Account) GetPoolModeRetryCount() int {
|
||||
if a == nil || !a.IsPoolMode() || a.Credentials == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
raw, ok := a.Credentials["pool_mode_retry_count"]
|
||||
if !ok || raw == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
count := parsePoolModeRetryCount(raw)
|
||||
if count < 0 {
|
||||
return 0
|
||||
}
|
||||
if count > maxPoolModeRetryCount {
|
||||
return maxPoolModeRetryCount
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func parsePoolModeRetryCount(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
|
||||
// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码
|
||||
func isPoolModeRetryableStatus(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -1134,33 +1203,97 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetQuotaLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
return a.getExtraFloat64("quota_limit")
|
||||
}
|
||||
|
||||
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
|
||||
func (a *Account) GetQuotaUsed() float64 {
|
||||
return a.getExtraFloat64("quota_used")
|
||||
}
|
||||
|
||||
// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaDailyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_daily_limit")
|
||||
}
|
||||
|
||||
// GetQuotaDailyUsed 获取当日已用额度(美元)
|
||||
func (a *Account) GetQuotaDailyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_daily_used")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyLimit 获取周额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaWeeklyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_limit")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyUsed 获取本周已用额度(美元)
|
||||
func (a *Account) GetQuotaWeeklyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_used")
|
||||
}
|
||||
|
||||
// getExtraFloat64 从 Extra 中读取指定 key 的 float64 值
|
||||
func (a *Account) getExtraFloat64(key string) float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_used"]; ok {
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
limit := a.GetQuotaLimit()
|
||||
if limit <= 0 {
|
||||
return false
|
||||
// getExtraTime 从 Extra 中读取 RFC3339 时间戳
|
||||
func (a *Account) getExtraTime(key string) time.Time {
|
||||
if a.Extra == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return a.GetQuotaUsed() >= limit
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
|
||||
return t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
}
|
||||
|
||||
// isPeriodExpired 检查指定周期(自 periodStart 起经过 dur)是否已过期
|
||||
func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true // 从未使用过,视为过期(下次 increment 会初始化)
|
||||
}
|
||||
return time.Since(periodStart) >= dur
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true)
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
// 总额度
|
||||
if limit := a.GetQuotaLimit(); limit > 0 && a.GetQuotaUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
|
||||
117
backend/internal/service/account_pool_mode_test.go
Normal file
117
backend/internal/service/account_pool_mode_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetPoolModeRetryCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "default_when_not_pool_mode",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "default_when_missing_retry_count",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "supports_float64_from_json_credentials",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": float64(5),
|
||||
},
|
||||
},
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "supports_json_number",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": json.Number("4"),
|
||||
},
|
||||
},
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
name: "supports_string_value",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "2",
|
||||
},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "negative_value_is_clamped_to_zero",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": -1,
|
||||
},
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "oversized_value_is_clamped_to_max",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": 99,
|
||||
},
|
||||
},
|
||||
expected: maxPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "invalid_value_falls_back_to_default",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "oops",
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -68,9 +68,9 @@ type AccountRepository interface {
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周)
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
|
||||
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
|
||||
// ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0
|
||||
ResetQuotaUsed(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
|
||||
@@ -406,8 +406,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
||||
mergeAccountExtra(account, updates)
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
|
||||
102
backend/internal/service/account_test_service_openai_test.go
Normal file
102
backend/internal/service/account_test_service_openai_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIAccountTestRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updatedExtra map[string]any
|
||||
rateLimitedID int64
|
||||
rateLimitedAt *time.Time
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
r.rateLimitedAt = &resetAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
||||
|
||||
`))
|
||||
resp.Header.Set("x-codex-primary-used-percent", "88")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "42")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 89,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"])
|
||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 88,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.Error(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, int64(88), repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
|
||||
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
}
|
||||
@@ -359,6 +359,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return usage, nil
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
|
||||
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
@@ -367,7 +368,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
|
||||
if (usage.FiveHour == nil || usage.SevenDay == nil) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||
mergeAccountExtra(account, updates)
|
||||
if usage.UpdatedAt == nil {
|
||||
@@ -409,6 +410,40 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if usage == nil {
|
||||
return true
|
||||
}
|
||||
if usage.FiveHour == nil || usage.SevenDay == nil {
|
||||
return true
|
||||
}
|
||||
if account.IsRateLimited() {
|
||||
return true
|
||||
}
|
||||
return isOpenAICodexSnapshotStale(account, now)
|
||||
}
|
||||
|
||||
func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool {
|
||||
if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() {
|
||||
return false
|
||||
}
|
||||
if account.Extra == nil {
|
||||
return true
|
||||
}
|
||||
raw, ok := account.Extra["codex_usage_updated_at"]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
ts, err := parseTime(fmt.Sprint(raw))
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return now.Sub(ts) >= openAIProbeCacheTTL
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return true
|
||||
@@ -478,20 +513,34 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
updates, err := extractOpenAICodexProbeUpdates(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||
if resp == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
68
backend/internal/service/account_usage_service_test.go
Normal file
68
backend/internal/service/account_usage_service_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rateLimitedUntil := time.Now().Add(5 * time.Minute)
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
FiveHour: &UsageProgress{Utilization: 0},
|
||||
SevenDay: &UsageProgress{Utilization: 0},
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) {
|
||||
t.Fatal("expected rate-limited account to force codex snapshot refresh")
|
||||
}
|
||||
|
||||
if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) {
|
||||
t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh")
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) {
|
||||
t.Fatal("expected missing 5h snapshot to require refresh")
|
||||
}
|
||||
|
||||
staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339)
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_enabled": true,
|
||||
"codex_usage_updated_at": staleAt,
|
||||
},
|
||||
}, usage, now) {
|
||||
t.Fatal("expected stale ws snapshot to trigger refresh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||
if err != nil {
|
||||
t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
t.Fatal("expected codex probe updates from 429 headers")
|
||||
}
|
||||
if got := updates["codex_5h_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
|
||||
}
|
||||
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
@@ -1349,6 +1349,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
now := time.Now()
|
||||
for i := range accounts {
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
|
||||
}
|
||||
return accounts, result.Total, nil
|
||||
}
|
||||
|
||||
@@ -1484,9 +1488,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
if len(input.Extra) > 0 {
|
||||
// 保留 quota_used,防止编辑账号时意外重置配额用量
|
||||
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
|
||||
input.Extra["quota_used"] = oldQuotaUsed
|
||||
// 保留配额用量字段,防止编辑账号时意外重置
|
||||
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
|
||||
if v, ok := account.Extra[key]; ok {
|
||||
input.Extra[key] = v
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
@@ -1717,16 +1723,10 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if err := s.accountRepo.ClearError(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Status = StatusActive
|
||||
account.ErrorMessage = ""
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
|
||||
@@ -1384,7 +1384,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
|
||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
logBody, maxBytes := s.getLogConfig()
|
||||
@@ -1517,6 +1517,80 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// Budget 整流:检测 budget_tokens 约束错误并自动修正重试
|
||||
if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) {
|
||||
errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "budget_constraint_error",
|
||||
Message: errMsg,
|
||||
Detail: s.getUpstreamErrorDetail(respBody),
|
||||
})
|
||||
|
||||
// 修正 claudeReq 的 thinking 参数(adaptive 模式不修正)
|
||||
if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" {
|
||||
retryClaudeReq := claudeReq
|
||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||||
// 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针
|
||||
retryClaudeReq.Thinking = &antigravity.ThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: BudgetRectifyBudgetTokens,
|
||||
}
|
||||
if retryClaudeReq.MaxTokens < BudgetRectifyMinMaxTokens {
|
||||
retryClaudeReq.MaxTokens = BudgetRectifyMaxTokens
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts)
|
||||
if txErr == nil {
|
||||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: prefix,
|
||||
account: account,
|
||||
proxyURL: proxyURL,
|
||||
accessToken: accessToken,
|
||||
action: action,
|
||||
body: retryGeminiBody,
|
||||
c: c,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
accountRepo: s.accountRepo,
|
||||
handleError: s.handleUpstreamError,
|
||||
requestedModel: originalModel,
|
||||
isStickySession: isStickySession,
|
||||
groupID: 0,
|
||||
sessionHash: "",
|
||||
})
|
||||
if retryErr == nil {
|
||||
retryResp := retryResult.resp
|
||||
if retryResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
resp = retryResp
|
||||
respBody = nil
|
||||
} else {
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
respBody = retryBody
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
Header: retryResp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: budget rectifier retry failed: %v", account.ID, retryErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
|
||||
|
||||
@@ -78,11 +78,11 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key APIKey
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
name string
|
||||
key APIKey
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
}{
|
||||
{
|
||||
name: "all windows active",
|
||||
|
||||
@@ -43,16 +43,19 @@ type BillingCache interface {
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
InputPricePerTokenPriority float64 // priority service tier 下每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -61,6 +64,28 @@ const (
|
||||
openAIGPT54LongContextOutputMultiplier = 1.5
|
||||
)
|
||||
|
||||
func normalizeBillingServiceTier(serviceTier string) string {
|
||||
return strings.ToLower(strings.TrimSpace(serviceTier))
|
||||
}
|
||||
|
||||
func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool {
|
||||
if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" {
|
||||
return false
|
||||
}
|
||||
return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0
|
||||
}
|
||||
|
||||
func serviceTierCostMultiplier(serviceTier string) float64 {
|
||||
switch normalizeBillingServiceTier(serviceTier) {
|
||||
case "priority":
|
||||
return 2.0
|
||||
case "flex":
|
||||
return 0.5
|
||||
default:
|
||||
return 1.0
|
||||
}
|
||||
}
|
||||
|
||||
// UsageTokens 使用的token数量
|
||||
type UsageTokens struct {
|
||||
InputTokens int
|
||||
@@ -173,30 +198,60 @@ func (s *BillingService) initFallbackPricing() {
|
||||
|
||||
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
|
||||
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
OutputPricePerTokenPriority: 20e-6, // $20 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
CacheReadPricePerTokenPriority: 0.25e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.4(业务指定价格)
|
||||
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
InputPricePerTokenPriority: 5e-6, // $5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
OutputPricePerTokenPriority: 30e-6, // $30 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
}
|
||||
// OpenAI GPT-5.2(本地兜底)
|
||||
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
|
||||
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
InputPricePerTokenPriority: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
OutputPricePerTokenPriority: 24e-6, // $24 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
CacheReadPricePerTokenPriority: 0.3e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
|
||||
}
|
||||
@@ -241,6 +296,10 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
switch normalized {
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.2":
|
||||
return s.fallbackPrices["gpt-5.2"]
|
||||
case "gpt-5.2-codex":
|
||||
return s.fallbackPrices["gpt-5.2-codex"]
|
||||
case "gpt-5.3-codex":
|
||||
return s.fallbackPrices["gpt-5.3-codex"]
|
||||
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
|
||||
@@ -269,16 +328,19 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
@@ -295,6 +357,10 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -303,6 +369,21 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
cacheReadPricePerToken := pricing.CacheReadPricePerToken
|
||||
tierMultiplier := 1.0
|
||||
if usePriorityServiceTierPricing(serviceTier, pricing) {
|
||||
if pricing.InputPricePerTokenPriority > 0 {
|
||||
inputPricePerToken = pricing.InputPricePerTokenPriority
|
||||
}
|
||||
if pricing.OutputPricePerTokenPriority > 0 {
|
||||
outputPricePerToken = pricing.OutputPricePerTokenPriority
|
||||
}
|
||||
if pricing.CacheReadPricePerTokenPriority > 0 {
|
||||
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
|
||||
}
|
||||
} else {
|
||||
tierMultiplier = serviceTierCostMultiplier(serviceTier)
|
||||
}
|
||||
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
@@ -329,7 +410,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
|
||||
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
// 计算总费用
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
|
||||
@@ -522,3 +522,189 @@ func TestCalculateCost_LargeTokenCount(t *testing.T) {
|
||||
require.False(t, math.IsNaN(cost.TotalCost))
|
||||
require.False(t, math.IsInf(cost.TotalCost, 0))
|
||||
}
|
||||
|
||||
func TestServiceTierCostMultiplier(t *testing.T) {
|
||||
require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12)
|
||||
require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12)
|
||||
require.InDelta(t, 0.5, serviceTierCostMultiplier("flex"), 1e-12)
|
||||
require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12)
|
||||
require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8}
|
||||
|
||||
baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) {
|
||||
pricingSvc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-5.4": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
InputCostPerTokenPriority: 5e-6,
|
||||
OutputCostPerToken: 15e-6,
|
||||
OutputCostPerTokenPriority: 30e-6,
|
||||
CacheCreationInputTokenCost: 2.5e-6,
|
||||
CacheReadInputTokenCost: 0.25e-6,
|
||||
CacheReadInputTokenCostPriority: 0.5e-6,
|
||||
LongContextInputTokenThreshold: 272000,
|
||||
LongContextInputCostMultiplier: 2.0,
|
||||
LongContextOutputCostMultiplier: 1.5,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewBillingService(&config.Config{}, pricingSvc)
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.Equal(t, 272000, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
gpt52, err := svc.GetModelPricing("gpt-5.2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gpt52)
|
||||
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
|
||||
|
||||
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gpt52Codex)
|
||||
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"custom-no-priority": {
|
||||
InputCostPerToken: 1e-6,
|
||||
OutputCostPerToken: 2e-6,
|
||||
CacheCreationInputTokenCost: 0.5e-6,
|
||||
CacheReadInputTokenCost: 0.25e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
gpt52, err := svc.GetModelPricing("gpt-5.2")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12)
|
||||
|
||||
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"dynamic-tier-model": {
|
||||
InputCostPerToken: 1e-6,
|
||||
InputCostPerTokenPriority: 2e-6,
|
||||
OutputCostPerToken: 3e-6,
|
||||
OutputCostPerTokenPriority: 6e-6,
|
||||
CacheCreationInputTokenCost: 4e-6,
|
||||
CacheCreationInputTokenCostAbove1hr: 5e-6,
|
||||
CacheReadInputTokenCost: 7e-7,
|
||||
CacheReadInputTokenCostPriority: 8e-7,
|
||||
LongContextInputTokenThreshold: 999,
|
||||
LongContextInputCostMultiplier: 1.5,
|
||||
LongContextOutputCostMultiplier: 1.25,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pricing, err := svc.GetModelPricing("dynamic-tier-model")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
|
||||
require.True(t, pricing.SupportsCacheBreakdown)
|
||||
require.InDelta(t, 7e-7, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.Equal(t, 999, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
@@ -175,6 +175,13 @@ const (
|
||||
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||
|
||||
// =========================
|
||||
// Request Rectifier (请求整流器)
|
||||
// =========================
|
||||
|
||||
// SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget).
|
||||
SettingKeyRectifierSettings = "rectifier_settings"
|
||||
|
||||
// =========================
|
||||
// Sora S3 存储配置
|
||||
// =========================
|
||||
|
||||
@@ -177,6 +177,36 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
body: []byte(`overloaded`),
|
||||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||
},
|
||||
{
|
||||
name: "pool_mode_custom_error_codes_hit_returns_matched",
|
||||
account: &Account{
|
||||
ID: 7,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(401), float64(403)},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyMatched,
|
||||
},
|
||||
{
|
||||
name: "pool_mode_without_custom_error_codes_returns_skipped",
|
||||
account: &Account{
|
||||
ID: 8,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicySkipped,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -190,6 +220,48 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) {
|
||||
t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 30,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.False(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
})
|
||||
|
||||
t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 31,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(401)},
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) {
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "DroppedBetas removes both context-1m and fast-mode",
|
||||
name: "DroppedBetas removes fast-mode only",
|
||||
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
|
||||
tokens: claude.DroppedBetas,
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
want: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -117,21 +117,21 @@ func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
|
||||
drop := droppedBetaSet()
|
||||
|
||||
got := mergeAnthropicBetaDropping(required, incoming, drop)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
|
||||
require.NotContains(t, got, "context-1m-2025-08-07")
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,foo-beta", got)
|
||||
require.Contains(t, got, "context-1m-2025-08-07")
|
||||
require.NotContains(t, got, "fast-mode-2026-02-01")
|
||||
}
|
||||
|
||||
func TestDroppedBetaSet(t *testing.T) {
|
||||
// Base set contains DroppedBetas
|
||||
base := droppedBetaSet()
|
||||
require.Contains(t, base, claude.BetaContext1M)
|
||||
require.NotContains(t, base, claude.BetaContext1M)
|
||||
require.Contains(t, base, claude.BetaFastMode)
|
||||
require.Len(t, base, len(claude.DroppedBetas))
|
||||
|
||||
// With extra tokens
|
||||
extended := droppedBetaSet(claude.BetaClaudeCode)
|
||||
require.Contains(t, extended, claude.BetaContext1M)
|
||||
require.NotContains(t, extended, claude.BetaContext1M)
|
||||
require.Contains(t, extended, claude.BetaFastMode)
|
||||
require.Contains(t, extended, claude.BetaClaudeCode)
|
||||
require.Len(t, extended, len(claude.DroppedBetas)+1)
|
||||
@@ -148,6 +148,32 @@ func TestBuildBetaTokenSet(t *testing.T) {
|
||||
require.Empty(t, empty)
|
||||
}
|
||||
|
||||
func TestContainsBetaToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
token string
|
||||
want bool
|
||||
}{
|
||||
{"present in middle", "oauth-2025-04-20,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true},
|
||||
{"present at start", "fast-mode-2026-02-01,oauth-2025-04-20", "fast-mode-2026-02-01", true},
|
||||
{"present at end", "oauth-2025-04-20,fast-mode-2026-02-01", "fast-mode-2026-02-01", true},
|
||||
{"only token", "fast-mode-2026-02-01", "fast-mode-2026-02-01", true},
|
||||
{"not present", "oauth-2025-04-20,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", false},
|
||||
{"with spaces", "oauth-2025-04-20, fast-mode-2026-02-01 , interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true},
|
||||
{"empty header", "", "fast-mode-2026-02-01", false},
|
||||
{"empty token", "fast-mode-2026-02-01", "", false},
|
||||
{"partial match", "fast-mode-2026-02-01-extra", "fast-mode-2026-02-01", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := containsBetaToken(tt.header, tt.token)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
|
||||
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
|
||||
got := stripBetaTokensWithSet(header, map[string]struct{}{})
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
@@ -258,6 +259,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
if !hasEmptyContent && !containsThinkingBlocks {
|
||||
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
|
||||
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
|
||||
out = removeThinkingDependentContextStrategies(out)
|
||||
return out
|
||||
}
|
||||
return body
|
||||
@@ -395,6 +397,10 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
} else {
|
||||
return body
|
||||
}
|
||||
// Removing "thinking" makes any context_management strategy that requires it invalid
|
||||
// (e.g. clear_thinking_20251015). Strip those entries so the retry request does not
|
||||
// receive a 400 "strategy requires thinking to be enabled or adaptive".
|
||||
out = removeThinkingDependentContextStrategies(out)
|
||||
}
|
||||
if modified {
|
||||
msgsBytes, err := json.Marshal(messages)
|
||||
@@ -409,6 +415,49 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// removeThinkingDependentContextStrategies 从 context_management.edits 中移除
|
||||
// 需要 thinking 启用的策略(如 clear_thinking_20251015)。
|
||||
// 当顶层 "thinking" 字段被禁用时必须调用,否则上游会返回
|
||||
// "strategy requires thinking to be enabled or adaptive"。
|
||||
func removeThinkingDependentContextStrategies(body []byte) []byte {
|
||||
jsonStr := *(*string)(unsafe.Pointer(&body))
|
||||
editsRes := gjson.Get(jsonStr, "context_management.edits")
|
||||
if !editsRes.Exists() || !editsRes.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
var filtered []json.RawMessage
|
||||
hasRemoved := false
|
||||
editsRes.ForEach(func(_, v gjson.Result) bool {
|
||||
if v.Get("type").String() == "clear_thinking_20251015" {
|
||||
hasRemoved = true
|
||||
return true
|
||||
}
|
||||
filtered = append(filtered, json.RawMessage(v.Raw))
|
||||
return true
|
||||
})
|
||||
|
||||
if !hasRemoved {
|
||||
return body
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
if b, err := sjson.DeleteBytes(body, "context_management.edits"); err == nil {
|
||||
return b
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
filteredBytes, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
if b, err := sjson.SetRawBytes(body, "context_management.edits", filteredBytes); err == nil {
|
||||
return b
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
|
||||
// signature/thought_signature validation issues involving tool blocks.
|
||||
//
|
||||
@@ -444,6 +493,28 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
|
||||
if _, exists := req["thinking"]; exists {
|
||||
delete(req, "thinking")
|
||||
modified = true
|
||||
// Remove context_management strategies that require thinking to be enabled
|
||||
// (e.g. clear_thinking_20251015), otherwise upstream returns 400.
|
||||
if cm, ok := req["context_management"].(map[string]any); ok {
|
||||
if edits, ok := cm["edits"].([]any); ok {
|
||||
filtered := make([]any, 0, len(edits))
|
||||
for _, edit := range edits {
|
||||
if editMap, ok := edit.(map[string]any); ok {
|
||||
if editMap["type"] == "clear_thinking_20251015" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, edit)
|
||||
}
|
||||
if len(filtered) != len(edits) {
|
||||
if len(filtered) == 0 {
|
||||
delete(cm, "edits")
|
||||
} else {
|
||||
cm["edits"] = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
@@ -675,3 +746,90 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Thinking Budget Rectifier
|
||||
// =========================
|
||||
|
||||
const (
|
||||
// BudgetRectifyBudgetTokens is the budget_tokens value to set when rectifying.
|
||||
BudgetRectifyBudgetTokens = 32000
|
||||
// BudgetRectifyMaxTokens is the max_tokens value to set when rectifying.
|
||||
BudgetRectifyMaxTokens = 64000
|
||||
// BudgetRectifyMinMaxTokens is the minimum max_tokens that must exceed budget_tokens.
|
||||
BudgetRectifyMinMaxTokens = 32001
|
||||
)
|
||||
|
||||
// isThinkingBudgetConstraintError detects whether an upstream error message indicates
|
||||
// a budget_tokens constraint violation (e.g. "budget_tokens >= 1024").
|
||||
// Matches three conditions (all must be true):
|
||||
// 1. Contains "budget_tokens" or "budget tokens"
|
||||
// 2. Contains "thinking"
|
||||
// 3. Contains ">= 1024" or "greater than or equal to 1024" or ("1024" + "input should be")
|
||||
func isThinkingBudgetConstraintError(errMsg string) bool {
|
||||
m := strings.ToLower(errMsg)
|
||||
|
||||
// Condition 1: budget_tokens or budget tokens
|
||||
hasBudget := strings.Contains(m, "budget_tokens") || strings.Contains(m, "budget tokens")
|
||||
if !hasBudget {
|
||||
return false
|
||||
}
|
||||
|
||||
// Condition 2: thinking
|
||||
if !strings.Contains(m, "thinking") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Condition 3: constraint indicator
|
||||
if strings.Contains(m, ">= 1024") || strings.Contains(m, "greater than or equal to 1024") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(m, "1024") && strings.Contains(m, "input should be") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// RectifyThinkingBudget modifies the request body to fix budget_tokens constraint errors.
|
||||
// It sets thinking.budget_tokens = 32000, thinking.type = "enabled" (unless adaptive),
|
||||
// and ensures max_tokens >= 32001.
|
||||
// Returns (modified body, true) if changes were applied, or (original body, false) if not.
|
||||
func RectifyThinkingBudget(body []byte) ([]byte, bool) {
|
||||
// If thinking type is "adaptive", skip rectification entirely
|
||||
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
||||
if thinkingType == "adaptive" {
|
||||
return body, false
|
||||
}
|
||||
|
||||
modified := body
|
||||
changed := false
|
||||
|
||||
// Set thinking.type = "enabled"
|
||||
if thinkingType != "enabled" {
|
||||
if result, err := sjson.SetBytes(modified, "thinking.type", "enabled"); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Set thinking.budget_tokens = 32000
|
||||
currentBudget := gjson.GetBytes(modified, "thinking.budget_tokens").Int()
|
||||
if currentBudget != BudgetRectifyBudgetTokens {
|
||||
if result, err := sjson.SetBytes(modified, "thinking.budget_tokens", BudgetRectifyBudgetTokens); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure max_tokens >= BudgetRectifyMinMaxTokens
|
||||
maxTokens := gjson.GetBytes(modified, "max_tokens").Int()
|
||||
if maxTokens < int64(BudgetRectifyMinMaxTokens) {
|
||||
if result, err := sjson.SetBytes(modified, "max_tokens", BudgetRectifyMaxTokens); err == nil {
|
||||
modified = result
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
return modified, changed
|
||||
}
|
||||
|
||||
@@ -439,6 +439,210 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
||||
require.Contains(t, content1["text"], "tool_result")
|
||||
}
|
||||
|
||||
// ============ Group 6b: context_management.edits 清理测试 ============
|
||||
|
||||
// removeThinkingDependentContextStrategies — 边界用例
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_NoContextManagement(t *testing.T) {
|
||||
input := []byte(`{"thinking":{"type":"enabled"},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "无 context_management 字段时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_EmptyEdits(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "edits 为空数组时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_NoClearThinkingEntry(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"other_strategy"}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
require.Equal(t, input, out, "edits 中无 clear_thinking_20251015 时应原样返回")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_RemovesSingleEntry(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasEdits := cm["edits"]
|
||||
require.False(t, hasEdits, "所有 edits 均为 clear_thinking_20251015 时应删除 edits 键")
|
||||
}
|
||||
|
||||
func TestRemoveThinkingDependentContextStrategies_MixedEntries(t *testing.T) {
|
||||
input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_strategy","param":1}]},"messages":[]}`)
|
||||
out := removeThinkingDependentContextStrategies(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留其他条目")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "other_strategy", edit0["type"])
|
||||
}
|
||||
|
||||
// FilterThinkingBlocksForRetry — 包含 context_management 的场景
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_FastPath(t *testing.T) {
|
||||
// 快速路径:messages 中无 thinking 块,仅有顶层 thinking 字段
|
||||
// 这条路径曾因提前 return 跳过 removeThinkingDependentContextStrategies 而存在 bug
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Hello"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasEdits := cm["edits"]
|
||||
require.False(t, hasEdits, "fast path 下 clear_thinking_20251015 应被移除,edits 键应被删除")
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_WithThinkingBlocks(t *testing.T) {
|
||||
// 完整路径:messages 中有 thinking 块(非 fast path)
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"keep_this"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"some thought","signature":"sig"},
|
||||
{"type":"text","text":"Answer"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 keep_this")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "keep_this", edit0["type"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_NoContextManagement_Unaffected(t *testing.T) {
|
||||
// 无 context_management 时不应报错,且 thinking 正常被移除
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled"},
|
||||
"messages":[{"role":"user","content":[{"type":"text","text":"Hi"}]}]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
_, hasCM := req["context_management"]
|
||||
require.False(t, hasCM)
|
||||
}
|
||||
|
||||
// FilterSignatureSensitiveBlocksForRetry — 包含 context_management 的场景
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_RemovesClearThinkingStrategy(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"thought","signature":"sig"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking, "顶层 thinking 应被移除")
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
if rawEdits, hasEdits := cm["edits"]; hasEdits {
|
||||
edits, ok := rawEdits.([]any)
|
||||
require.True(t, ok)
|
||||
for _, e := range edits {
|
||||
em, ok := e.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, "clear_thinking_20251015", em["type"], "clear_thinking_20251015 应被移除")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_PreservesNonThinkingStrategies(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled"},
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_edit"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"t","signature":"s"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 other_edit")
|
||||
edit0, ok := edits[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "other_edit", edit0["type"])
|
||||
}
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_NoThinkingField_ContextManagementUntouched(t *testing.T) {
|
||||
// 没有顶层 thinking 字段时,context_management 不应被修改
|
||||
input := []byte(`{
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"t","signature":"s"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
cm, ok := req["context_management"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
edits, ok := cm["edits"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, edits, 1, "无顶层 thinking 时 context_management 不应被修改")
|
||||
}
|
||||
|
||||
// ============ Group 7: ParseGatewayRequest 补充单元测试 ============
|
||||
|
||||
// Task 7.1 — 类型校验边界测试
|
||||
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 40 * 1024 * 1024
|
||||
defaultMaxLineSize = 500 * 1024 * 1024
|
||||
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
|
||||
// to match real Claude CLI traffic as closely as possible. When we need a visual
|
||||
// separator between system blocks, we add "\n\n" at concatenation time.
|
||||
@@ -526,6 +526,7 @@ type GatewayService struct {
|
||||
userGroupRateSF singleflight.Group
|
||||
modelsListCache *gocache.Cache
|
||||
modelsListCacheTTL time.Duration
|
||||
settingService *SettingService
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
@@ -553,6 +554,7 @@ func NewGatewayService(
|
||||
sessionLimitCache SessionLimitCache,
|
||||
rpmCache RPMCache,
|
||||
digestStore *DigestSessionStore,
|
||||
settingService *SettingService,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@@ -579,6 +581,7 @@ func NewGatewayService(
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
settingService: settingService,
|
||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
@@ -994,6 +997,11 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
|
||||
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
|
||||
}
|
||||
|
||||
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
|
||||
func GenerateSessionUUID(seed string) string {
|
||||
return generateSessionUUID(seed)
|
||||
}
|
||||
|
||||
func generateSessionUUID(seed string) string {
|
||||
if seed == "" {
|
||||
return uuid.NewString()
|
||||
@@ -3328,10 +3336,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
if account.Platform == PlatformSora {
|
||||
return s.isSoraModelSupportedByAccount(account, requestedModel)
|
||||
}
|
||||
// OpenAI 透传模式:仅替换认证,允许所有模型
|
||||
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
|
||||
return true
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
@@ -4069,7 +4073,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if readErr == nil {
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if s.isThinkingBlockSignatureError(respBody) {
|
||||
if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@@ -4186,7 +4190,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
// 不是thinking签名错误,恢复响应体
|
||||
// 不是签名错误(或整流器已关闭),继续检查 budget 约束
|
||||
errMsg := extractUpstreamErrorMessage(respBody)
|
||||
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "budget_constraint_error",
|
||||
Message: errMsg,
|
||||
Detail: func() string {
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
|
||||
rectifiedBody, applied := RectifyThinkingBudget(body)
|
||||
if applied && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if buildErr == nil {
|
||||
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = budgetRetryResp
|
||||
break
|
||||
}
|
||||
if budgetRetryResp != nil && budgetRetryResp.Body != nil {
|
||||
_ = budgetRetryResp.Body.Close()
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr)
|
||||
} else {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
}
|
||||
}
|
||||
@@ -4278,7 +4320,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -4308,7 +4354,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
@@ -4543,7 +4593,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -4573,7 +4627,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
@@ -5288,6 +5346,19 @@ func droppedBetaSet(extra ...string) map[string]struct{} {
|
||||
return m
|
||||
}
|
||||
|
||||
// containsBetaToken checks if a comma-separated header value contains the given token.
|
||||
func containsBetaToken(header, token string) bool {
|
||||
if header == "" || token == "" {
|
||||
return false
|
||||
}
|
||||
for _, p := range strings.Split(header, ",") {
|
||||
if strings.TrimSpace(p) == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildBetaTokenSet(tokens []string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(tokens))
|
||||
for _, t := range tokens {
|
||||
@@ -6437,7 +6508,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
@@ -6928,7 +6999,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
|
||||
@@ -319,7 +319,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -687,16 +687,20 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, len(candidates), topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
@@ -705,16 +709,23 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
candidate := selectionOrder[0]
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: candidate.account.ID,
|
||||
MaxConcurrency: candidate.account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
for _, candidate := range selectionOrder {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: fresh.ID,
|
||||
MaxConcurrency: fresh.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
|
||||
return nil, len(candidates), topK, loadSkew, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
|
||||
@@ -12,6 +12,78 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAISnapshotCacheStub struct {
|
||||
SchedulerCache
|
||||
snapshotAccounts []*Account
|
||||
accountsByID map[int64]*Account
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
|
||||
if len(s.snapshotAccounts) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
out := make([]*Account, 0, len(s.snapshotAccounts))
|
||||
for _, account := range s.snapshotAccounts {
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
cloned := *account
|
||||
out = append(out, &cloned)
|
||||
}
|
||||
return out, true, nil
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.accountsByID == nil {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.accountsByID[accountID]
|
||||
if account == nil {
|
||||
return nil, nil
|
||||
}
|
||||
cloned := *account
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10101)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
staleSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(31002), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10102)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
stalePrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshPrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(32002), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(9)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -38,7 +39,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("parse anthropic request: %w", err)
|
||||
}
|
||||
originalModel := anthropicReq.Model
|
||||
isStream := anthropicReq.Stream
|
||||
clientStream := anthropicReq.Stream // client's original stream preference
|
||||
|
||||
// 2. Convert Anthropic → Responses
|
||||
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
|
||||
@@ -46,6 +47,16 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("convert anthropic to responses: %w", err)
|
||||
}
|
||||
|
||||
// Upstream always uses streaming (upstream may not support sync mode).
|
||||
// The client's original preference determines the response format.
|
||||
responsesReq.Stream = true
|
||||
isStream := true
|
||||
|
||||
// 2b. Handle BetaFastMode → service_tier: "priority"
|
||||
if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
|
||||
responsesReq.ServiceTier = "priority"
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
// 分组级降级:账号未映射时使用分组默认映射模型
|
||||
@@ -72,7 +83,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||
}
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
} else if promptCacheKey != "" {
|
||||
reqBody["prompt_cache_key"] = promptCacheKey
|
||||
}
|
||||
// OAuth codex transform forces stream=true upstream, so always use
|
||||
// the streaming response handler regardless of what the client asked.
|
||||
isStream = true
|
||||
@@ -94,6 +110,12 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// Override session_id with a deterministic UUID derived from the sticky
|
||||
// session key (buildUpstreamRequest may have set it to the raw value).
|
||||
if promptCacheKey != "" {
|
||||
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
|
||||
}
|
||||
|
||||
// 7. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
@@ -152,12 +174,26 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 9. Handle normal response
|
||||
// Upstream is always streaming; choose response format based on client preference.
|
||||
var result *OpenAIForwardResult
|
||||
var handleErr error
|
||||
if isStream {
|
||||
if clientStream {
|
||||
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
} else {
|
||||
result, handleErr = s.handleAnthropicNonStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
// Client wants JSON: buffer the streaming response and assemble a JSON reply.
|
||||
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
}
|
||||
|
||||
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||||
if handleErr == nil && result != nil {
|
||||
if responsesReq.ServiceTier != "" {
|
||||
st := responsesReq.ServiceTier
|
||||
result.ServiceTier = &st
|
||||
}
|
||||
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
|
||||
re := responsesReq.Reasoning.Effort
|
||||
result.ReasoningEffort = &re
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
@@ -227,9 +263,13 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
|
||||
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// handleAnthropicNonStreamingResponse reads a Responses API JSON response,
|
||||
// converts it to Anthropic Messages format, and writes it to the client.
|
||||
func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
|
||||
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
|
||||
// the upstream streaming response, finds the terminal event (response.completed
|
||||
// / response.incomplete / response.failed), converts the complete response to
|
||||
// Anthropic Messages JSON format, and writes it to the client.
|
||||
// This is used when the client requested stream=false but the upstream is always
|
||||
// streaming.
|
||||
func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
@@ -238,29 +278,61 @@ func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read upstream response: %w", err)
|
||||
}
|
||||
|
||||
var responsesResp apicompat.ResponsesResponse
|
||||
if err := json.Unmarshal(respBody, &responsesResp); err != nil {
|
||||
return nil, fmt.Errorf("parse responses response: %w", err)
|
||||
}
|
||||
|
||||
anthropicResp := apicompat.ResponsesToAnthropic(&responsesResp, originalModel)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
if responsesResp.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: responsesResp.Usage.InputTokens,
|
||||
OutputTokens: responsesResp.Usage.OutputTokens,
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
if responsesResp.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = responsesResp.Usage.InputTokensDetails.CachedTokens
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Terminal events carry the complete ResponsesResponse with output + usage.
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
|
||||
return nil, fmt.Errorf("upstream stream ended without terminal event")
|
||||
}
|
||||
|
||||
anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel)
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
@@ -278,6 +350,9 @@ func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
|
||||
|
||||
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
||||
// converts each to Anthropic SSE events, and writes them to the client.
|
||||
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
||||
// pattern to send Anthropic ping events during periods of upstream silence,
|
||||
// preventing proxy/client timeout disconnections.
|
||||
func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
@@ -293,6 +368,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
state := apicompat.NewResponsesEventToAnthropicState()
|
||||
@@ -304,28 +380,35 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
// resultWithUsage builds the final result snapshot.
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}
|
||||
payload := line[6:]
|
||||
}
|
||||
|
||||
// processDataLine handles a single "data: ..." SSE line from upstream.
|
||||
// Returns (clientDisconnected bool).
|
||||
processDataLine := func(payload string) bool {
|
||||
if firstChunk {
|
||||
firstChunk = false
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
// Parse the Responses SSE event
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
@@ -352,28 +435,36 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
// Client disconnected — return collected usage
|
||||
logger.L().Info("openai messages stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(events) > 0 {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
// finalizeStream sends any remaining Anthropic events and returns the result.
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
// handleScanErr logs scanner errors if meaningful.
|
||||
handleScanErr := func(err error) {
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages stream: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
@@ -381,27 +472,94 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the Anthropic stream is properly terminated
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
}
|
||||
c.Writer.Flush()
|
||||
// ── Determine keepalive interval ──
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
||||
if keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
// ── With keepalive: goroutine + channel + select ──
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// Upstream closed
|
||||
return finalizeStream()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
// Send Anthropic-format ping event
|
||||
if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil {
|
||||
// Client disconnected
|
||||
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeAnthropicError writes an error response in Anthropic Messages API format.
|
||||
|
||||
@@ -334,3 +334,225 @@ func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testin
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, 0, usageRepo.lastLog.InputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_Gpt54LongContextBillsWholeSession(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_gpt54_long_context",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 300000,
|
||||
OutputTokens: 2000,
|
||||
},
|
||||
Model: "gpt-5.4-2026-03-05",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1014},
|
||||
User: &User{ID: 2014},
|
||||
Account: &Account{ID: 3014},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
|
||||
expectedInput := 300000 * 2.5e-6 * 2.0
|
||||
expectedOutput := 2000 * 15e-6 * 1.5
|
||||
require.InDelta(t, expectedInput, usageRepo.lastLog.InputCost, 1e-10)
|
||||
require.InDelta(t, expectedOutput, usageRepo.lastLog.OutputCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
require.InDelta(t, (expectedInput+expectedOutput)*1.1, usageRepo.lastLog.ActualCost, 1e-10)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ServiceTierPriorityUsesFastPricing(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "priority"
|
||||
usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_service_tier_priority",
|
||||
ServiceTier: &serviceTier,
|
||||
Usage: usage,
|
||||
Model: "gpt-5.4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1015},
|
||||
User: &User{ID: 2015},
|
||||
Account: &Account{ID: 3015},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
|
||||
|
||||
baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 100, OutputTokens: 50}, 1.0)
|
||||
require.NoError(t, calcErr)
|
||||
require.InDelta(t, baseCost.TotalCost*2, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ServiceTierFlexHalvesCost(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "flex"
|
||||
usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50, CacheReadInputTokens: 20}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_service_tier_flex",
|
||||
ServiceTier: &serviceTier,
|
||||
Usage: usage,
|
||||
Model: "gpt-5.4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1016},
|
||||
User: &User{ID: 2016},
|
||||
Account: &Account{ID: 3016},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
|
||||
baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 80, OutputTokens: 50, CacheReadTokens: 20}, 1.0)
|
||||
require.NoError(t, calcErr)
|
||||
require.InDelta(t, baseCost.TotalCost*0.5, usageRepo.lastLog.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIServiceTier(t *testing.T) {
|
||||
t.Run("fast maps to priority", func(t *testing.T) {
|
||||
got := normalizeOpenAIServiceTier(" fast ")
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, "priority", *got)
|
||||
})
|
||||
|
||||
t.Run("default ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("default"))
|
||||
})
|
||||
|
||||
t.Run("invalid ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
|
||||
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
|
||||
require.Nil(t, extractOpenAIServiceTier(nil))
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
serviceTier := "priority"
|
||||
reasoning := "high"
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_billing_model_override",
|
||||
BillingModel: "gpt-5.1-codex",
|
||||
Model: "gpt-5.1",
|
||||
ServiceTier: &serviceTier,
|
||||
ReasoningEffort: &reasoning,
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
},
|
||||
Duration: 2 * time.Second,
|
||||
FirstTokenMs: func() *int { v := 120; return &v }(),
|
||||
},
|
||||
APIKey: &APIKey{ID: 10, GroupID: i64p(11), Group: &Group{ID: 11, RateMultiplier: 1.2}},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
UserAgent: "codex-cli/1.0",
|
||||
IPAddress: "127.0.0.1",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model)
|
||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
|
||||
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
|
||||
require.Equal(t, reasoning, *usageRepo.lastLog.ReasoningEffort)
|
||||
require.NotNil(t, usageRepo.lastLog.UserAgent)
|
||||
require.Equal(t, "codex-cli/1.0", *usageRepo.lastLog.UserAgent)
|
||||
require.NotNil(t, usageRepo.lastLog.IPAddress)
|
||||
require.Equal(t, "127.0.0.1", *usageRepo.lastLog.IPAddress)
|
||||
require.NotNil(t, usageRepo.lastLog.GroupID)
|
||||
require.Equal(t, int64(11), *usageRepo.lastLog.GroupID)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
subscription := &UserSubscription{ID: 99}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_subscription_billing",
|
||||
Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
|
||||
User: &User{ID: 200},
|
||||
Account: &Account{ID: 300},
|
||||
Subscription: subscription,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, BillingTypeSubscription, usageRepo.lastLog.BillingType)
|
||||
require.NotNil(t, usageRepo.lastLog.SubscriptionID)
|
||||
require.Equal(t, subscription.ID, *usageRepo.lastLog.SubscriptionID)
|
||||
require.Equal(t, 1, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.cfg.RunMode = config.RunModeSimple
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_simple_mode",
|
||||
Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1000},
|
||||
User: &User{ID: 2000},
|
||||
Account: &Account{ID: 3000},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
@@ -213,6 +213,9 @@ type OpenAIForwardResult struct {
|
||||
// This is set by the Anthropic Messages conversion path where
|
||||
// the mapped upstream model differs from the client-facing model.
|
||||
BillingModel string
|
||||
// ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex".
|
||||
// Nil means the request did not specify a recognized tier.
|
||||
ServiceTier *string
|
||||
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
|
||||
// Stored for usage records display; nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
@@ -1026,7 +1029,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
||||
|
||||
// 3. 按优先级 + LRU 选择最佳账号
|
||||
// Select by priority + LRU
|
||||
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
|
||||
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -1099,7 +1102,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
//
|
||||
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||
// Returns nil if no available account.
|
||||
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
@@ -1111,27 +1114,20 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo
|
||||
continue
|
||||
}
|
||||
|
||||
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
|
||||
if !acc.IsSchedulable() || !acc.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择优先级最高且最久未使用的账号
|
||||
// Select highest priority and least recently used
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
selected = fresh
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterAccount(acc, selected) {
|
||||
selected = acc
|
||||
if s.isBetterAccount(fresh, selected) {
|
||||
selected = fresh
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1309,13 +1305,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, false)
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL)
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
@@ -1359,13 +1359,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
shuffleWithinSortGroups(available)
|
||||
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
@@ -1377,11 +1381,15 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
// ============ Layer 3: Fallback wait ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, false)
|
||||
for _, acc := range candidates {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
AccountID: fresh.ID,
|
||||
MaxConcurrency: fresh.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
@@ -1418,11 +1426,44 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
|
||||
fresh := account
|
||||
if s.schedulerSnapshot != nil {
|
||||
current, err := s.getSchedulableAccount(ctx, account.ID)
|
||||
if err != nil || current == nil {
|
||||
return nil
|
||||
}
|
||||
fresh = current
|
||||
}
|
||||
|
||||
if !fresh.IsSchedulable() || !fresh.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !fresh.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
return fresh
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
)
|
||||
if s.schedulerSnapshot != nil {
|
||||
account, err = s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
} else {
|
||||
account, err = s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
if err != nil || account == nil {
|
||||
return account, err
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now())
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
@@ -2002,7 +2043,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, body)
|
||||
}
|
||||
@@ -2036,11 +2081,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
serviceTier := extractOpenAIServiceTier(reqBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: serviceTier,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: false,
|
||||
@@ -2195,6 +2242,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: reqModel,
|
||||
ServiceTier: extractOpenAIServiceTierFromBody(body),
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: false,
|
||||
@@ -2815,7 +2863,11 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: body,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
// Return appropriate error response
|
||||
@@ -3594,6 +3646,13 @@ type OpenAIRecordUsageInput struct {
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
|
||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
@@ -3628,7 +3687,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if result.BillingModel != "" {
|
||||
billingModel = result.BillingModel
|
||||
}
|
||||
cost, err := s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
}
|
||||
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
@@ -3649,6 +3712,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: billingModel,
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -3871,6 +3935,69 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow
|
||||
return updates
|
||||
}
|
||||
|
||||
func codexUsagePercentExhausted(value *float64) bool {
|
||||
return value != nil && *value >= 100-1e-9
|
||||
}
|
||||
|
||||
func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time {
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
return nil
|
||||
}
|
||||
baseTime := codexSnapshotBaseTime(snapshot, fallbackNow)
|
||||
if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil {
|
||||
resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
|
||||
return &resetAt
|
||||
}
|
||||
if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil {
|
||||
resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
|
||||
return &resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time {
|
||||
if len(extra) == 0 {
|
||||
return nil
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
||||
resetAt := progress.ResetsAt.UTC()
|
||||
return &resetAt
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
||||
resetAt := progress.ResetsAt.UTC()
|
||||
return &resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) {
|
||||
if account == nil || !account.IsOpenAI() {
|
||||
return nil, false
|
||||
}
|
||||
resetAt := codexRateLimitResetAtFromExtra(account.Extra, now)
|
||||
if resetAt == nil {
|
||||
return nil, false
|
||||
}
|
||||
if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) {
|
||||
return account.RateLimitResetAt, false
|
||||
}
|
||||
account.RateLimitResetAt = resetAt
|
||||
return resetAt, true
|
||||
}
|
||||
|
||||
func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time {
|
||||
resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now)
|
||||
if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 {
|
||||
return resetAt
|
||||
}
|
||||
_ = repo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
return resetAt
|
||||
}
|
||||
|
||||
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
|
||||
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
|
||||
if snapshot == nil {
|
||||
@@ -3880,16 +4007,22 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
return
|
||||
}
|
||||
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) == 0 {
|
||||
now := time.Now()
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, now)
|
||||
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now)
|
||||
if len(updates) == 0 && resetAt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update account's Extra field asynchronously
|
||||
go func() {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
if len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}
|
||||
if resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -4047,6 +4180,40 @@ func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *s
|
||||
return &value
|
||||
}
|
||||
|
||||
func extractOpenAIServiceTier(reqBody map[string]any) *string {
|
||||
if reqBody == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := reqBody["service_tier"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return normalizeOpenAIServiceTier(raw)
|
||||
}
|
||||
|
||||
func extractOpenAIServiceTierFromBody(body []byte) *string {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalizeOpenAIServiceTier(gjson.GetBytes(body, "service_tier").String())
|
||||
}
|
||||
|
||||
func normalizeOpenAIServiceTier(raw string) *string {
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
if value == "fast" {
|
||||
value = "priority"
|
||||
}
|
||||
switch value {
|
||||
case "priority", "flex":
|
||||
return &value
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
|
||||
if c != nil {
|
||||
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {
|
||||
|
||||
@@ -671,7 +671,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"service_tier":"fast","input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstreamSSE := strings.Join([]string{
|
||||
`data: {"type":"response.output_text.delta","delta":"h"}`,
|
||||
@@ -711,6 +711,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test
|
||||
require.GreaterOrEqual(t, time.Since(start), time.Duration(0))
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
require.GreaterOrEqual(t, *result.FirstTokenMs, 0)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "priority", *result.ServiceTier)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) {
|
||||
@@ -777,7 +779,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd
|
||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
||||
c.Request.Header.Set("X-Test", "keep")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"service_tier":"flex","max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
|
||||
@@ -803,8 +805,11 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "flex", *result.ServiceTier)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, originalBody, upstream.lastBody)
|
||||
require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String())
|
||||
|
||||
@@ -29,6 +29,13 @@ func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit,
|
||||
openAIStickyLegacyDualWriteTotal.Load()
|
||||
}
|
||||
|
||||
// DeriveSessionHashFromSeed computes the current-format sticky-session hash
|
||||
// from an arbitrary seed string.
|
||||
func DeriveSessionHashFromSeed(seed string) string {
|
||||
currentHash, _ := deriveOpenAISessionHashes(seed)
|
||||
return currentHash
|
||||
}
|
||||
|
||||
func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) {
|
||||
normalized := strings.TrimSpace(sessionID)
|
||||
if normalized == "" {
|
||||
|
||||
@@ -48,6 +48,43 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
account := Account{
|
||||
ID: 12,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
RateLimitResetAt: &rateLimitedUntil,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连")
|
||||
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl")
|
||||
require.NoError(t, getErr)
|
||||
require.Zero(t, boundAccountID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
|
||||
@@ -1853,6 +1853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
wsPath,
|
||||
account.ProxyID != nil && account.Proxy != nil,
|
||||
)
|
||||
var dialErr *openAIWSDialError
|
||||
if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
|
||||
}
|
||||
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
|
||||
}
|
||||
defer lease.Release()
|
||||
@@ -2136,6 +2140,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errMsg := strings.TrimSpace(errMsgRaw)
|
||||
if errMsg == "" {
|
||||
errMsg = "Upstream websocket error"
|
||||
@@ -2302,6 +2307,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
RequestID: responseID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: extractOpenAIServiceTier(reqBody),
|
||||
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: true,
|
||||
@@ -2639,6 +2645,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
wsPath,
|
||||
account.ProxyID != nil && account.Proxy != nil,
|
||||
)
|
||||
var dialErr *openAIWSDialError
|
||||
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
|
||||
}
|
||||
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
|
||||
return nil, NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
@@ -2777,6 +2787,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
|
||||
@@ -2913,6 +2924,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
RequestID: responseID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: extractOpenAIServiceTierFromBody(payload),
|
||||
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: true,
|
||||
@@ -3604,6 +3616,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
|
||||
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errMsg := strings.TrimSpace(errMsgRaw)
|
||||
if errMsg == "" {
|
||||
errMsg = "OpenAI websocket prewarm error"
|
||||
@@ -3798,7 +3811,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
||||
if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() {
|
||||
if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
|
||||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -3867,6 +3880,36 @@ func classifyOpenAIWSAcquireError(err error) string {
|
||||
return "acquire_conn"
|
||||
}
|
||||
|
||||
func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool {
|
||||
code := strings.ToLower(strings.TrimSpace(codeRaw))
|
||||
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
|
||||
msg := strings.ToLower(strings.TrimSpace(msgRaw))
|
||||
|
||||
if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) {
|
||||
if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI {
|
||||
return
|
||||
}
|
||||
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
||||
}
|
||||
|
||||
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
|
||||
code := strings.ToLower(strings.TrimSpace(codeRaw))
|
||||
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
|
||||
@@ -3882,6 +3925,9 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
|
||||
case "previous_response_not_found":
|
||||
return "previous_response_not_found", true
|
||||
}
|
||||
if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||
return "upstream_rate_limited", false
|
||||
}
|
||||
if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
|
||||
return "upgrade_required", true
|
||||
}
|
||||
@@ -3927,9 +3973,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
|
||||
case strings.Contains(errType, "permission"),
|
||||
strings.Contains(code, "forbidden"):
|
||||
return http.StatusForbidden
|
||||
case strings.Contains(errType, "rate_limit"),
|
||||
strings.Contains(code, "rate_limit"),
|
||||
strings.Contains(code, "insufficient_quota"):
|
||||
case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""):
|
||||
return http.StatusTooManyRequests
|
||||
default:
|
||||
return http.StatusBadGateway
|
||||
|
||||
@@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -424,6 +424,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
require.True(t, result.OpenAIWSMode)
|
||||
require.Equal(t, 2, result.Usage.InputTokens)
|
||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "priority", *result.ServiceTier)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("未收到 passthrough turn 结果回调")
|
||||
}
|
||||
@@ -2593,7 +2595,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect
|
||||
require.NoError(t, err)
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`))
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false,"service_tier":"flex"}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
// 立即关闭客户端,模拟客户端在 relay 期间断连。
|
||||
@@ -2611,6 +2613,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect
|
||||
require.Equal(t, "resp_ingress_disconnect", result.RequestID)
|
||||
require.Equal(t, 2, result.Usage.InputTokens)
|
||||
require.Equal(t, 1, result.Usage.OutputTokens)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "flex", *result.ServiceTier)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("未收到断连后的 turn 结果回调")
|
||||
}
|
||||
|
||||
477
backend/internal/service/openai_ws_ratelimit_signal_test.go
Normal file
477
backend/internal/service/openai_ws_ratelimit_signal_test.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIWSRateLimitSignalRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCalls []time.Time
|
||||
updateExtra []map[string]any
|
||||
}
|
||||
|
||||
type openAICodexSnapshotAsyncRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
updateExtraCh chan map[string]any
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
type openAICodexExtraListRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
r.rateLimitCalls = append(r.rateLimitCalls, resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtra = append(r.updateExtra, copied)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
if r.updateExtraCh != nil {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtraCh <- copied
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
_ = platform
|
||||
_ = accountType
|
||||
_ = status
|
||||
_ = search
|
||||
_ = groupID
|
||||
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
resetAt := time.Now().Add(2 * time.Hour).Unix()
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "rate_limit_exceeded",
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached",
|
||||
"resets_at": resetAt,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
|
||||
account := Account{
|
||||
ID: 501,
|
||||
Name: "openai-ws-rate-limit-event",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: upstream,
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, &account, body)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP")
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("x-codex-primary-used-percent", "100")
|
||||
w.Header().Set("x-codex-primary-reset-after-seconds", "7200")
|
||||
w.Header().Set("x-codex-primary-window-minutes", "10080")
|
||||
w.Header().Set("x-codex-secondary-used-percent", "3")
|
||||
w.Header().Set("x-codex-secondary-reset-after-seconds", "1800")
|
||||
w.Header().Set("x-codex-secondary-window-minutes", "300")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"type":"rate_limit_exceeded","message":"rate limited"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
|
||||
account := Account{
|
||||
ID: 502,
|
||||
Name: "openai-ws-rate-limit-handshake",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": server.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: upstream,
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, &account, body)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP")
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库")
|
||||
require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
resetAt := time.Now().Add(90 * time.Minute).Unix()
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}}`),
|
||||
},
|
||||
}
|
||||
captureConn.events[0] = []byte(strings.ReplaceAll(string(captureConn.events[0]), "PLACEHOLDER", strconv.FormatInt(resetAt, 10)))
|
||||
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
account := Account{
|
||||
ID: 503,
|
||||
Name: "openai-ingress-rate-limit",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.CloseNow() }()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, &account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = clientConn.CloseNow() }()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.Error(t, serverErr)
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) {
|
||||
repo := &openAICodexSnapshotAsyncRepo{
|
||||
updateExtraCh: make(chan map[string]any, 1),
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: ptrFloat64WS(100),
|
||||
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||
SecondaryUsedPercent: ptrFloat64WS(12),
|
||||
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||
SecondaryWindowMinutes: ptrIntWS(300),
|
||||
}
|
||||
before := time.Now()
|
||||
svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot)
|
||||
|
||||
select {
|
||||
case updates := <-repo.updateExtraCh:
|
||||
require.Equal(t, 100.0, updates["codex_7d_used_percent"])
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 快照落库超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case resetAt := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 100% 自动切换限流超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) {
|
||||
repo := &openAICodexSnapshotAsyncRepo{
|
||||
updateExtraCh: make(chan map[string]any, 1),
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: ptrFloat64WS(94),
|
||||
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||
SecondaryUsedPercent: ptrFloat64WS(22),
|
||||
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||
SecondaryWindowMinutes: ptrIntWS(300),
|
||||
}
|
||||
svc.updateCodexUsageSnapshot(context.Background(), 602, snapshot)
|
||||
|
||||
select {
|
||||
case <-repo.updateExtraCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 快照落库超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case resetAt := <-repo.rateLimitCh:
|
||||
t.Fatalf("unexpected rate limit reset at: %v", resetAt)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func ptrFloat64WS(v float64) *float64 { return &v }
|
||||
func ptrIntWS(v int) *int { return &v }
|
||||
|
||||
func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) {
|
||||
resetAt := time.Now().Add(6 * 24 * time.Hour)
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
|
||||
fresh, err := svc.getSchedulableAccount(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fresh)
|
||||
require.NotNil(t, fresh.RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待旧快照补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) {
|
||||
resetAt := time.Now().Add(4 * 24 * time.Hour)
|
||||
repo := &openAICodexExtraListRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{
|
||||
ID: 702,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}}},
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Len(t, accounts, 1)
|
||||
require.NotNil(t, accounts[0].RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待列表补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) {
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached"))
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", ""))
|
||||
}
|
||||
@@ -77,6 +77,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
return errors.New("token is empty")
|
||||
}
|
||||
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
||||
requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
|
||||
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
||||
@@ -178,6 +179,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
ServiceTier: requestServiceTier,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
@@ -225,6 +227,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
ServiceTier: requestServiceTier,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
|
||||
@@ -40,13 +40,17 @@ var (
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
InputCostPerTokenPriority float64 `json:"input_cost_per_token_priority"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
OutputCostPerTokenPriority float64 `json:"output_cost_per_token_priority"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
CacheReadInputTokenCostPriority float64 `json:"cache_read_input_token_cost_priority"`
|
||||
LongContextInputTokenThreshold int `json:"long_context_input_token_threshold,omitempty"`
|
||||
LongContextInputCostMultiplier float64 `json:"long_context_input_cost_multiplier,omitempty"`
|
||||
LongContextOutputCostMultiplier float64 `json:"long_context_output_cost_multiplier,omitempty"`
|
||||
SupportsServiceTier bool `json:"supports_service_tier"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
@@ -62,10 +66,14 @@ type PricingRemoteClient interface {
|
||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||
type LiteLLMRawEntry struct {
|
||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||
InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority"`
|
||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||
OutputCostPerTokenPriority *float64 `json:"output_cost_per_token_priority"`
|
||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||
CacheReadInputTokenCostPriority *float64 `json:"cache_read_input_token_cost_priority"`
|
||||
SupportsServiceTier bool `json:"supports_service_tier"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
@@ -324,14 +332,21 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
LiteLLMProvider: entry.LiteLLMProvider,
|
||||
Mode: entry.Mode,
|
||||
SupportsPromptCaching: entry.SupportsPromptCaching,
|
||||
SupportsServiceTier: entry.SupportsServiceTier,
|
||||
}
|
||||
|
||||
if entry.InputCostPerToken != nil {
|
||||
pricing.InputCostPerToken = *entry.InputCostPerToken
|
||||
}
|
||||
if entry.InputCostPerTokenPriority != nil {
|
||||
pricing.InputCostPerTokenPriority = *entry.InputCostPerTokenPriority
|
||||
}
|
||||
if entry.OutputCostPerToken != nil {
|
||||
pricing.OutputCostPerToken = *entry.OutputCostPerToken
|
||||
}
|
||||
if entry.OutputCostPerTokenPriority != nil {
|
||||
pricing.OutputCostPerTokenPriority = *entry.OutputCostPerTokenPriority
|
||||
}
|
||||
if entry.CacheCreationInputTokenCost != nil {
|
||||
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
||||
}
|
||||
@@ -341,6 +356,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
if entry.CacheReadInputTokenCost != nil {
|
||||
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
||||
}
|
||||
if entry.CacheReadInputTokenCostPriority != nil {
|
||||
pricing.CacheReadInputTokenCostPriority = *entry.CacheReadInputTokenCostPriority
|
||||
}
|
||||
if entry.OutputCostPerImage != nil {
|
||||
pricing.OutputCostPerImage = *entry.OutputCostPerImage
|
||||
}
|
||||
|
||||
@@ -1,11 +1,40 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParsePricingData_ParsesPriorityAndServiceTierFields(t *testing.T) {
|
||||
svc := &PricingService{}
|
||||
body := []byte(`{
|
||||
"gpt-5.4": {
|
||||
"input_cost_per_token": 0.0000025,
|
||||
"input_cost_per_token_priority": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"output_cost_per_token_priority": 0.00003,
|
||||
"cache_creation_input_token_cost": 0.0000025,
|
||||
"cache_read_input_token_cost": 0.00000025,
|
||||
"cache_read_input_token_cost_priority": 0.0000005,
|
||||
"supports_service_tier": true,
|
||||
"supports_prompt_caching": true,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
}
|
||||
}`)
|
||||
|
||||
data, err := svc.parsePricingData(body)
|
||||
require.NoError(t, err)
|
||||
pricing := data["gpt-5.4"]
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 3e-5, pricing.OutputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 5e-7, pricing.CacheReadInputTokenCostPriority, 1e-12)
|
||||
require.True(t, pricing.SupportsServiceTier)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_Gpt53CodexSparkUsesGpt51CodexPricing(t *testing.T) {
|
||||
sparkPricing := &LiteLLMModelPricing{InputCostPerToken: 1}
|
||||
gpt53Pricing := &LiteLLMModelPricing{InputCostPerToken: 9}
|
||||
@@ -68,3 +97,64 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T)
|
||||
require.InDelta(t, 2.0, got.LongContextInputCostMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) {
|
||||
raw := map[string]any{
|
||||
"gpt-5.4": map[string]any{
|
||||
"input_cost_per_token": 2.5e-6,
|
||||
"input_cost_per_token_priority": 5e-6,
|
||||
"output_cost_per_token": 15e-6,
|
||||
"output_cost_per_token_priority": 30e-6,
|
||||
"cache_read_input_token_cost": 0.25e-6,
|
||||
"cache_read_input_token_cost_priority": 0.5e-6,
|
||||
"supports_service_tier": true,
|
||||
"supports_prompt_caching": true,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(raw)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &PricingService{}
|
||||
pricingMap, err := svc.parsePricingData(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
pricing := pricingMap["gpt-5.4"]
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 30e-6, pricing.OutputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadInputTokenCost, 1e-12)
|
||||
require.InDelta(t, 0.5e-6, pricing.CacheReadInputTokenCostPriority, 1e-12)
|
||||
require.True(t, pricing.SupportsServiceTier)
|
||||
}
|
||||
|
||||
func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) {
|
||||
svc := &PricingService{}
|
||||
pricingData, err := svc.parsePricingData([]byte(`{
|
||||
"gpt-5.4": {
|
||||
"input_cost_per_token": 0.0000025,
|
||||
"input_cost_per_token_priority": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"output_cost_per_token_priority": 0.00003,
|
||||
"cache_read_input_token_cost": 0.00000025,
|
||||
"cache_read_input_token_cost_priority": 0.0000005,
|
||||
"supports_service_tier": true,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
}
|
||||
}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
pricing := pricingData["gpt-5.4"]
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 0.0000025, pricing.InputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 0.000005, pricing.InputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.000015, pricing.OutputCostPerToken, 1e-12)
|
||||
require.InDelta(t, 0.00003, pricing.OutputCostPerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.00000025, pricing.CacheReadInputTokenCost, 1e-12)
|
||||
require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12)
|
||||
require.True(t, pricing.SupportsServiceTier)
|
||||
}
|
||||
|
||||
@@ -28,6 +28,17 @@ type RateLimitService struct {
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
|
||||
type SuccessfulTestRecoveryResult struct {
|
||||
ClearedError bool
|
||||
ClearedRateLimit bool
|
||||
}
|
||||
|
||||
// AccountRecoveryOptions 控制账号恢复时的附加行为。
|
||||
type AccountRecoveryOptions struct {
|
||||
InvalidateToken bool
|
||||
}
|
||||
|
||||
type geminiUsageCacheEntry struct {
|
||||
windowStart time.Time
|
||||
cachedAt time.Time
|
||||
@@ -87,6 +98,9 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
|
||||
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return ErrorPolicySkipped
|
||||
}
|
||||
if account.IsPoolMode() {
|
||||
return ErrorPolicySkipped
|
||||
}
|
||||
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
|
||||
return ErrorPolicyTempUnscheduled
|
||||
}
|
||||
@@ -96,9 +110,16 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||
|
||||
// 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。
|
||||
if account.IsPoolMode() && !customErrorCodesEnabled {
|
||||
slog.Info("pool_mode_error_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
// apikey 类型账号:检查自定义错误码配置
|
||||
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return false
|
||||
@@ -615,6 +636,7 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
|
||||
if account.Platform == PlatformOpenAI {
|
||||
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
@@ -878,6 +900,23 @@ func pickSooner(a, b *time.Time) *time.Time {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RateLimitService) persistOpenAICodexSnapshot(ctx context.Context, account *Account, headers http.Header) {
|
||||
if s == nil || s.accountRepo == nil || account == nil || headers == nil {
|
||||
return
|
||||
}
|
||||
snapshot := ParseCodexRateLimitHeaders(headers)
|
||||
if snapshot == nil {
|
||||
return
|
||||
}
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, updates); err != nil {
|
||||
slog.Warn("openai_codex_snapshot_persist_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
// OpenAI 的 usage_limit_reached 错误格式:
|
||||
//
|
||||
@@ -1022,6 +1061,42 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecoverAccountState 按需恢复账号的可恢复运行时状态。
|
||||
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SuccessfulTestRecoveryResult{}
|
||||
if account.Status == StatusError {
|
||||
if err := s.accountRepo.ClearError(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.ClearedError = true
|
||||
if options.InvalidateToken && s.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := s.tokenCacheInvalidator.InvalidateToken(ctx, account); invalidateErr != nil {
|
||||
slog.Warn("recover_account_state_invalidate_token_failed", "account_id", accountID, "error", invalidateErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasRecoverableRuntimeState(account) {
|
||||
if err := s.ClearRateLimit(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.ClearedRateLimit = true
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RecoverAccountAfterSuccessfulTest 将一次成功测试视为正常请求,
|
||||
// 按需恢复 error / rate-limit / overload / temp-unsched / model-rate-limit 等运行时状态。
|
||||
func (s *RateLimitService) RecoverAccountAfterSuccessfulTest(ctx context.Context, accountID int64) (*SuccessfulTestRecoveryResult, error) {
|
||||
return s.RecoverAccountState(ctx, accountID, AccountRecoveryOptions{})
|
||||
}
|
||||
|
||||
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
||||
if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil {
|
||||
return err
|
||||
@@ -1038,6 +1113,36 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasRecoverableRuntimeState(account *Account) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.RateLimitedAt != nil || account.RateLimitResetAt != nil || account.OverloadUntil != nil || account.TempUnschedulableUntil != nil {
|
||||
return true
|
||||
}
|
||||
if len(account.Extra) == 0 {
|
||||
return false
|
||||
}
|
||||
return hasNonEmptyMapValue(account.Extra, "model_rate_limits") || hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes")
|
||||
}
|
||||
|
||||
func hasNonEmptyMapValue(extra map[string]any, key string) bool {
|
||||
raw, ok := extra[key]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
return len(typed) > 0
|
||||
case map[string]string:
|
||||
return len(typed) > 0
|
||||
case []any:
|
||||
return len(typed) > 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) {
|
||||
now := time.Now().Unix()
|
||||
if s.tempUnschedCache != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -13,16 +14,34 @@ import (
|
||||
|
||||
type rateLimitClearRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
getByIDAccount *Account
|
||||
getByIDErr error
|
||||
getByIDCalls int
|
||||
clearErrorCalls int
|
||||
clearRateLimitCalls int
|
||||
clearAntigravityCalls int
|
||||
clearModelRateLimitCalls int
|
||||
clearTempUnschedCalls int
|
||||
clearErrorErr error
|
||||
clearRateLimitErr error
|
||||
clearAntigravityErr error
|
||||
clearModelRateLimitErr error
|
||||
clearTempUnschedulableErr error
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
r.getByIDCalls++
|
||||
if r.getByIDErr != nil {
|
||||
return nil, r.getByIDErr
|
||||
}
|
||||
return r.getByIDAccount, nil
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) ClearError(ctx context.Context, id int64) error {
|
||||
r.clearErrorCalls++
|
||||
return r.clearErrorErr
|
||||
}
|
||||
|
||||
func (r *rateLimitClearRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
r.clearRateLimitCalls++
|
||||
return r.clearRateLimitErr
|
||||
@@ -48,6 +67,11 @@ type tempUnschedCacheRecorder struct {
|
||||
deleteErr error
|
||||
}
|
||||
|
||||
type recoverTokenInvalidatorStub struct {
|
||||
accounts []*Account
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *tempUnschedCacheRecorder) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
|
||||
return nil
|
||||
}
|
||||
@@ -61,6 +85,11 @@ func (c *tempUnschedCacheRecorder) DeleteTempUnsched(ctx context.Context, accoun
|
||||
return c.deleteErr
|
||||
}
|
||||
|
||||
func (s *recoverTokenInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
s.accounts = append(s.accounts, account)
|
||||
return s.err
|
||||
}
|
||||
|
||||
func TestRateLimitService_ClearRateLimit_AlsoClearsTempUnschedulable(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
@@ -170,3 +199,108 @@ func TestRateLimitService_ClearRateLimit_WithoutTempUnschedCache(t *testing.T) {
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLimitRelatedState(t *testing.T) {
|
||||
now := time.Now()
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 42,
|
||||
Status: StatusError,
|
||||
RateLimitedAt: &now,
|
||||
TempUnschedulableUntil: &now,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": now.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
"antigravity_quota_scopes": map[string]any{"gemini": true},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.ClearedError)
|
||||
require.True(t, result.ClearedRateLimit)
|
||||
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Equal(t, 1, repo.clearRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearAntigravityCalls)
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
require.Equal(t, []int64{42}, cache.deletedIDs)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 7,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{},
|
||||
},
|
||||
}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 7)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.ClearedError)
|
||||
require.False(t, result.ClearedRateLimit)
|
||||
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 0, repo.clearErrorCalls)
|
||||
require.Equal(t, 0, repo.clearRateLimitCalls)
|
||||
require.Equal(t, 0, repo.clearAntigravityCalls)
|
||||
require.Equal(t, 0, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 0, repo.clearTempUnschedCalls)
|
||||
require.Empty(t, cache.deletedIDs)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearErrorFailed(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 9,
|
||||
Status: StatusError,
|
||||
},
|
||||
clearErrorErr: errors.New("clear error failed"),
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 9)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, 1, repo.getByIDCalls)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Equal(t, 0, repo.clearRateLimitCalls)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountState_InvalidatesOAuthTokenOnErrorRecovery(t *testing.T) {
|
||||
repo := &rateLimitClearRepoStub{
|
||||
getByIDAccount: &Account{
|
||||
ID: 21,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusError,
|
||||
},
|
||||
}
|
||||
invalidator := &recoverTokenInvalidatorStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc.SetTokenCacheInvalidator(invalidator)
|
||||
|
||||
result, err := svc.RecoverAccountState(context.Background(), 21, AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.ClearedError)
|
||||
require.False(t, result.ClearedRateLimit)
|
||||
require.Equal(t, 1, repo.clearErrorCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
require.Equal(t, int64(21), invalidator.accounts[0].ID)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -141,6 +144,51 @@ func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type openAI429SnapshotRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
rateLimitedID int64
|
||||
updatedExtra map[string]any
|
||||
}
|
||||
|
||||
func (r *openAI429SnapshotRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAI429SnapshotRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) {
|
||||
repo := &openAI429SnapshotRepo{}
|
||||
svc := NewRateLimitService(repo, nil, nil, nil, nil)
|
||||
account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
svc.handle429(context.Background(), account, headers, nil)
|
||||
|
||||
if repo.rateLimitedID != account.ID {
|
||||
t.Fatalf("rateLimitedID = %d, want %d", repo.rateLimitedID, account.ID)
|
||||
}
|
||||
if len(repo.updatedExtra) == 0 {
|
||||
t.Fatal("expected codex snapshot to be persisted on 429")
|
||||
}
|
||||
if got := repo.updatedExtra["codex_5h_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
|
||||
}
|
||||
if got := repo.updatedExtra["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits(t *testing.T) {
|
||||
// Test the Normalize() method directly
|
||||
pUsed := 100.0
|
||||
|
||||
@@ -13,6 +13,7 @@ type ScheduledTestPlan struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover bool `json:"auto_recover"`
|
||||
LastRunAt *time.Time `json:"last_run_at"`
|
||||
NextRunAt *time.Time `json:"next_run_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
@@ -17,6 +17,7 @@ type ScheduledTestRunnerService struct {
|
||||
planRepo ScheduledTestPlanRepository
|
||||
scheduledSvc *ScheduledTestService
|
||||
accountTestSvc *AccountTestService
|
||||
rateLimitSvc *RateLimitService
|
||||
cfg *config.Config
|
||||
|
||||
cron *cron.Cron
|
||||
@@ -29,12 +30,14 @@ func NewScheduledTestRunnerService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
scheduledSvc *ScheduledTestService,
|
||||
accountTestSvc *AccountTestService,
|
||||
rateLimitSvc *RateLimitService,
|
||||
cfg *config.Config,
|
||||
) *ScheduledTestRunnerService {
|
||||
return &ScheduledTestRunnerService{
|
||||
planRepo: planRepo,
|
||||
scheduledSvc: scheduledSvc,
|
||||
accountTestSvc: accountTestSvc,
|
||||
rateLimitSvc: rateLimitSvc,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -127,6 +130,11 @@ func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *Sched
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err)
|
||||
}
|
||||
|
||||
// Auto-recover account if test succeeded and auto_recover is enabled.
|
||||
if result.Status == "success" && plan.AutoRecover {
|
||||
s.tryRecoverAccount(ctx, plan.AccountID, plan.ID)
|
||||
}
|
||||
|
||||
nextRun, err := computeNextRun(plan.CronExpression, time.Now())
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err)
|
||||
@@ -137,3 +145,26 @@ func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *Sched
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// tryRecoverAccount attempts to recover an account from recoverable runtime state.
|
||||
func (s *ScheduledTestRunnerService) tryRecoverAccount(ctx context.Context, accountID int64, planID int64) {
|
||||
if s.rateLimitSvc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
recovery, err := s.rateLimitSvc.RecoverAccountAfterSuccessfulTest(ctx, accountID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover failed: %v", planID, err)
|
||||
return
|
||||
}
|
||||
if recovery == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if recovery.ClearedError {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d recovered from error status", planID, accountID)
|
||||
}
|
||||
if recovery.ClearedRateLimit {
|
||||
logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d cleared rate-limit/runtime state", planID, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1194,6 +1194,59 @@ func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string {
|
||||
return ver
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
func (s *SettingService) GetRectifierSettings(ctx context.Context) (*RectifierSettings, error) {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRectifierSettings)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return DefaultRectifierSettings(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("get rectifier settings: %w", err)
|
||||
}
|
||||
if value == "" {
|
||||
return DefaultRectifierSettings(), nil
|
||||
}
|
||||
|
||||
var settings RectifierSettings
|
||||
if err := json.Unmarshal([]byte(value), &settings); err != nil {
|
||||
return DefaultRectifierSettings(), nil
|
||||
}
|
||||
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
// SetRectifierSettings 设置请求整流器配置
|
||||
func (s *SettingService) SetRectifierSettings(ctx context.Context, settings *RectifierSettings) error {
|
||||
if settings == nil {
|
||||
return fmt.Errorf("settings cannot be nil")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal rectifier settings: %w", err)
|
||||
}
|
||||
|
||||
return s.settingRepo.Set(ctx, SettingKeyRectifierSettings, string(data))
|
||||
}
|
||||
|
||||
// IsSignatureRectifierEnabled 判断签名整流是否启用(总开关 && 签名子开关)
|
||||
func (s *SettingService) IsSignatureRectifierEnabled(ctx context.Context) bool {
|
||||
settings, err := s.GetRectifierSettings(ctx)
|
||||
if err != nil {
|
||||
return true // fail-open: 查询失败时默认启用
|
||||
}
|
||||
return settings.Enabled && settings.ThinkingSignatureEnabled
|
||||
}
|
||||
|
||||
// IsBudgetRectifierEnabled 判断 Budget 整流是否启用(总开关 && Budget 子开关)
|
||||
func (s *SettingService) IsBudgetRectifierEnabled(ctx context.Context) bool {
|
||||
settings, err := s.GetRectifierSettings(ctx)
|
||||
if err != nil {
|
||||
return true // fail-open: 查询失败时默认启用
|
||||
}
|
||||
return settings.Enabled && settings.ThinkingBudgetEnabled
|
||||
}
|
||||
|
||||
// SetStreamTimeoutSettings 设置流超时处理配置
|
||||
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
|
||||
if settings == nil {
|
||||
|
||||
@@ -175,3 +175,19 @@ func DefaultStreamTimeoutSettings() *StreamTimeoutSettings {
|
||||
ThresholdWindowMinutes: 10,
|
||||
}
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"` // 总开关
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` // Thinking 签名整流
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` // Thinking Budget 整流
|
||||
}
|
||||
|
||||
// DefaultRectifierSettings 返回默认的整流器配置(全部启用)
|
||||
func DefaultRectifierSettings() *RectifierSettings {
|
||||
return &RectifierSettings{
|
||||
Enabled: true,
|
||||
ThinkingSignatureEnabled: true,
|
||||
ThinkingBudgetEnabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +98,8 @@ type UsageLog struct {
|
||||
AccountID int64
|
||||
RequestID string
|
||||
Model string
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API),
|
||||
// e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
|
||||
@@ -287,9 +287,10 @@ func ProvideScheduledTestRunnerService(
|
||||
planRepo ScheduledTestPlanRepository,
|
||||
scheduledSvc *ScheduledTestService,
|
||||
accountTestSvc *AccountTestService,
|
||||
rateLimitSvc *RateLimitService,
|
||||
cfg *config.Config,
|
||||
) *ScheduledTestRunnerService {
|
||||
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, cfg)
|
||||
svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, rateLimitSvc, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -23,10 +24,19 @@ import (
|
||||
|
||||
// Config paths
|
||||
const (
|
||||
ConfigFileName = "config.yaml"
|
||||
InstallLockFile = ".installed"
|
||||
ConfigFileName = "config.yaml"
|
||||
InstallLockFile = ".installed"
|
||||
defaultUserConcurrency = 5
|
||||
simpleModeAdminConcurrency = 30
|
||||
)
|
||||
|
||||
func setupDefaultAdminConcurrency() int {
|
||||
if strings.EqualFold(strings.TrimSpace(os.Getenv("RUN_MODE")), config.RunModeSimple) {
|
||||
return simpleModeAdminConcurrency
|
||||
}
|
||||
return defaultUserConcurrency
|
||||
}
|
||||
|
||||
// GetDataDir returns the data directory for storing config and lock files.
|
||||
// Priority: DATA_DIR env > /app/data (if exists and writable) > current directory
|
||||
func GetDataDir() string {
|
||||
@@ -390,7 +400,7 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) {
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 0,
|
||||
Concurrency: 5,
|
||||
Concurrency: setupDefaultAdminConcurrency(),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
@@ -462,7 +472,7 @@ func writeConfigFile(cfg *SetupConfig) error {
|
||||
APIKeyPrefix string `yaml:"api_key_prefix"`
|
||||
RateMultiplier float64 `yaml:"rate_multiplier"`
|
||||
}{
|
||||
UserConcurrency: 5,
|
||||
UserConcurrency: defaultUserConcurrency,
|
||||
UserBalance: 0,
|
||||
APIKeyPrefix: "sk-",
|
||||
RateMultiplier: 1.0,
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package setup
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecideAdminBootstrap(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -49,3 +53,37 @@ func TestDecideAdminBootstrap(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupDefaultAdminConcurrency(t *testing.T) {
|
||||
t.Run("simple mode admin uses higher concurrency", func(t *testing.T) {
|
||||
t.Setenv("RUN_MODE", "simple")
|
||||
if got := setupDefaultAdminConcurrency(); got != simpleModeAdminConcurrency {
|
||||
t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", got, simpleModeAdminConcurrency)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("standard mode keeps existing default", func(t *testing.T) {
|
||||
t.Setenv("RUN_MODE", "standard")
|
||||
if got := setupDefaultAdminConcurrency(); got != defaultUserConcurrency {
|
||||
t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", got, defaultUserConcurrency)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteConfigFileKeepsDefaultUserConcurrency(t *testing.T) {
|
||||
t.Setenv("RUN_MODE", "simple")
|
||||
t.Setenv("DATA_DIR", t.TempDir())
|
||||
|
||||
if err := writeConfigFile(&SetupConfig{}); err != nil {
|
||||
t.Fatalf("writeConfigFile() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(GetConfigFilePath())
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(data), "user_concurrency: 5") {
|
||||
t.Fatalf("config missing default user concurrency, got:\n%s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
-- 070: Add auto_recover column to scheduled_test_plans
|
||||
-- When enabled, automatically recovers account from error/rate-limited state on successful test
|
||||
|
||||
ALTER TABLE scheduled_test_plans ADD COLUMN IF NOT EXISTS auto_recover BOOLEAN NOT NULL DEFAULT false;
|
||||
5
backend/migrations/070_add_usage_log_service_tier.sql
Normal file
5
backend/migrations/070_add_usage_log_service_tier.sql
Normal file
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE usage_logs
|
||||
ADD COLUMN IF NOT EXISTS service_tier VARCHAR(16);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_logs_service_tier_created_at
|
||||
ON usage_logs (service_tier, created_at);
|
||||
@@ -240,6 +240,16 @@ export async function clearRateLimit(id: number): Promise<Account> {
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Recover account runtime state in one call
|
||||
* @param id - Account ID
|
||||
* @returns Updated account
|
||||
*/
|
||||
export async function recoverState(id: number): Promise<Account> {
|
||||
const { data } = await apiClient.post<Account>(`/admin/accounts/${id}/recover-state`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset account quota usage
|
||||
* @param id - Account ID
|
||||
@@ -588,6 +598,7 @@ export const accountsAPI = {
|
||||
getTodayStats,
|
||||
getBatchTodayStats,
|
||||
clearRateLimit,
|
||||
recoverState,
|
||||
resetAccountQuota,
|
||||
getTempUnschedulableStatus,
|
||||
resetTempUnschedulable,
|
||||
|
||||
@@ -273,6 +273,41 @@ export async function updateStreamTimeoutSettings(
|
||||
return data
|
||||
}
|
||||
|
||||
// ==================== Rectifier Settings ====================
|
||||
|
||||
/**
|
||||
* Rectifier settings interface
|
||||
*/
|
||||
export interface RectifierSettings {
|
||||
enabled: boolean
|
||||
thinking_signature_enabled: boolean
|
||||
thinking_budget_enabled: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Get rectifier settings
|
||||
* @returns Rectifier settings
|
||||
*/
|
||||
export async function getRectifierSettings(): Promise<RectifierSettings> {
|
||||
const { data } = await apiClient.get<RectifierSettings>('/admin/settings/rectifier')
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Update rectifier settings
|
||||
* @param settings - Rectifier settings to update
|
||||
* @returns Updated settings
|
||||
*/
|
||||
export async function updateRectifierSettings(
|
||||
settings: RectifierSettings
|
||||
): Promise<RectifierSettings> {
|
||||
const { data } = await apiClient.put<RectifierSettings>(
|
||||
'/admin/settings/rectifier',
|
||||
settings
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
// ==================== Sora S3 Settings ====================
|
||||
|
||||
export interface SoraS3Settings {
|
||||
@@ -419,6 +454,8 @@ export const settingsAPI = {
|
||||
deleteAdminApiKey,
|
||||
getStreamTimeoutSettings,
|
||||
updateStreamTimeoutSettings,
|
||||
getRectifierSettings,
|
||||
updateRectifierSettings,
|
||||
getSoraS3Settings,
|
||||
updateSoraS3Settings,
|
||||
testSoraS3Connection,
|
||||
|
||||
@@ -73,21 +73,10 @@
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<div v-if="showQuotaLimit" class="flex items-center gap-1">
|
||||
<span
|
||||
:class="[
|
||||
'inline-flex items-center gap-1 rounded-md px-1.5 py-0.5 text-[10px] font-medium',
|
||||
quotaClass
|
||||
]"
|
||||
:title="quotaTooltip"
|
||||
>
|
||||
<svg class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M2.25 18.75a60.07 60.07 0 0115.797 2.101c.727.198 1.453-.342 1.453-1.096V18.75M3.75 4.5v.75A.75.75 0 013 6h-.75m0 0v-.375c0-.621.504-1.125 1.125-1.125H20.25M2.25 6v9m18-10.5v.75c0 .414.336.75.75.75h.75m-1.5-1.5h.375c.621 0 1.125.504 1.125 1.125v9.75c0 .621-.504 1.125-1.125 1.125h-.375m1.5-1.5H21a.75.75 0 00-.75.75v.75m0 0H3.75m0 0h-.375a1.125 1.125 0 01-1.125-1.125V15m1.5 1.5v-.75A.75.75 0 003 15h-.75M15 10.5a3 3 0 11-6 0 3 3 0 016 0zm3 0h.008v.008H18V10.5zm-12 0h.008v.008H6V10.5z" />
|
||||
</svg>
|
||||
<span class="font-mono">${{ formatCost(currentQuotaUsed) }}</span>
|
||||
<span class="text-gray-400 dark:text-gray-500">/</span>
|
||||
<span class="font-mono">${{ formatCost(account.quota_limit) }}</span>
|
||||
</span>
|
||||
<div v-if="showDailyQuota || showWeeklyQuota || showTotalQuota" class="flex items-center gap-1">
|
||||
<QuotaBadge v-if="showDailyQuota" :used="account.quota_daily_used ?? 0" :limit="account.quota_daily_limit!" label="D" />
|
||||
<QuotaBadge v-if="showWeeklyQuota" :used="account.quota_weekly_used ?? 0" :limit="account.quota_weekly_limit!" label="W" />
|
||||
<QuotaBadge v-if="showTotalQuota" :used="account.quota_used ?? 0" :limit="account.quota_limit!" />
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
@@ -96,6 +85,7 @@
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import type { Account } from '@/types'
|
||||
import QuotaBadge from './QuotaBadge.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
account: Account
|
||||
@@ -304,46 +294,17 @@ const rpmTooltip = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
// 是否显示配额限制(仅 apikey 类型且设置了 quota_limit)
|
||||
const showQuotaLimit = computed(() => {
|
||||
return (
|
||||
props.account.type === 'apikey' &&
|
||||
props.account.quota_limit !== undefined &&
|
||||
props.account.quota_limit !== null &&
|
||||
props.account.quota_limit > 0
|
||||
)
|
||||
// 是否显示各维度配额(仅 apikey 类型)
|
||||
const showDailyQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_daily_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
// 当前已用配额
|
||||
const currentQuotaUsed = computed(() => props.account.quota_used ?? 0)
|
||||
|
||||
// 配额状态样式
|
||||
const quotaClass = computed(() => {
|
||||
if (!showQuotaLimit.value) return ''
|
||||
|
||||
const used = currentQuotaUsed.value
|
||||
const limit = props.account.quota_limit || 0
|
||||
|
||||
if (used >= limit) {
|
||||
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
|
||||
}
|
||||
if (used >= limit * 0.8) {
|
||||
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
|
||||
}
|
||||
return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
const showWeeklyQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_weekly_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
// 配额提示文字
|
||||
const quotaTooltip = computed(() => {
|
||||
if (!showQuotaLimit.value) return ''
|
||||
|
||||
const used = currentQuotaUsed.value
|
||||
const limit = props.account.quota_limit || 0
|
||||
|
||||
if (used >= limit) {
|
||||
return t('admin.accounts.capacity.quota.exceeded')
|
||||
}
|
||||
return t('admin.accounts.capacity.quota.normal')
|
||||
const showTotalQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
// 格式化费用显示
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
<!-- Rate Limit Display (429) - Two-line layout -->
|
||||
<div v-if="isRateLimited" class="flex flex-col items-center gap-1">
|
||||
<span class="badge text-xs badge-warning">{{ t('admin.accounts.status.rateLimited') }}</span>
|
||||
<span class="text-[11px] text-gray-400 dark:text-gray-500">{{ rateLimitCountdown }}</span>
|
||||
<span class="text-[11px] text-gray-400 dark:text-gray-500">{{ rateLimitResumeText }}</span>
|
||||
</div>
|
||||
|
||||
<!-- Overload Display (529) - Two-line layout -->
|
||||
@@ -67,9 +67,9 @@
|
||||
</span>
|
||||
<!-- Tooltip -->
|
||||
<div
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 w-56 -translate-x-1/2 whitespace-normal rounded bg-gray-900 px-3 py-2 text-center text-xs leading-relaxed text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
>
|
||||
{{ t('admin.accounts.status.rateLimitedUntil', { time: formatTime(account.rate_limit_reset_at) }) }}
|
||||
{{ t('admin.accounts.status.rateLimitedUntil', { time: formatDateTime(account.rate_limit_reset_at) }) }}
|
||||
<div
|
||||
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700"
|
||||
></div>
|
||||
@@ -97,7 +97,7 @@
|
||||
</span>
|
||||
<!-- Tooltip -->
|
||||
<div
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 w-56 -translate-x-1/2 whitespace-normal rounded bg-gray-900 px-3 py-2 text-center text-xs leading-relaxed text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
>
|
||||
{{ t('admin.accounts.status.modelRateLimitedUntil', { model: formatScopeName(item.model), time: formatTime(item.reset_at) }) }}
|
||||
<div
|
||||
@@ -117,7 +117,7 @@
|
||||
</span>
|
||||
<!-- Tooltip -->
|
||||
<div
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 w-56 -translate-x-1/2 whitespace-normal rounded bg-gray-900 px-3 py-2 text-center text-xs leading-relaxed text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
>
|
||||
{{ t('admin.accounts.status.overloadedUntil', { time: formatTime(account.overload_until) }) }}
|
||||
<div
|
||||
@@ -132,7 +132,7 @@
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import type { Account } from '@/types'
|
||||
import { formatCountdownWithSuffix, formatTime } from '@/utils/format'
|
||||
import { formatCountdown, formatDateTime, formatCountdownWithSuffix, formatTime } from '@/utils/format'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
@@ -231,7 +231,12 @@ const hasError = computed(() => {
|
||||
|
||||
// Computed: countdown text for rate limit (429)
|
||||
const rateLimitCountdown = computed(() => {
|
||||
return formatCountdownWithSuffix(props.account.rate_limit_reset_at)
|
||||
return formatCountdown(props.account.rate_limit_reset_at)
|
||||
})
|
||||
|
||||
const rateLimitResumeText = computed(() => {
|
||||
if (!rateLimitCountdown.value) return ''
|
||||
return t('admin.accounts.status.rateLimitedAutoResume', { time: rateLimitCountdown.value })
|
||||
})
|
||||
|
||||
// Computed: countdown text for overload (529)
|
||||
|
||||
@@ -69,9 +69,39 @@
|
||||
<div v-else class="text-xs text-gray-400">-</div>
|
||||
</template>
|
||||
|
||||
<!-- OpenAI OAuth accounts: show Codex usage from extra field -->
|
||||
<!-- OpenAI OAuth accounts: prefer fresh usage query for active rate-limited rows -->
|
||||
<template v-else-if="account.platform === 'openai' && account.type === 'oauth'">
|
||||
<div v-if="hasCodexUsage" class="space-y-1">
|
||||
<div v-if="preferFetchedOpenAIUsage" class="space-y-1">
|
||||
<UsageProgressBar
|
||||
v-if="usageInfo?.five_hour"
|
||||
label="5h"
|
||||
:utilization="usageInfo.five_hour.utilization"
|
||||
:resets-at="usageInfo.five_hour.resets_at"
|
||||
:window-stats="usageInfo.five_hour.window_stats"
|
||||
color="indigo"
|
||||
/>
|
||||
<UsageProgressBar
|
||||
v-if="usageInfo?.seven_day"
|
||||
label="7d"
|
||||
:utilization="usageInfo.seven_day.utilization"
|
||||
:resets-at="usageInfo.seven_day.resets_at"
|
||||
:window-stats="usageInfo.seven_day.window_stats"
|
||||
color="emerald"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="isActiveOpenAIRateLimited && loading" class="space-y-1.5">
|
||||
<div class="flex items-center gap-1">
|
||||
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-1.5 w-8 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
<div class="flex items-center gap-1">
|
||||
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-1.5 w-8 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else-if="hasCodexUsage" class="space-y-1">
|
||||
<!-- 5h Window -->
|
||||
<UsageProgressBar
|
||||
v-if="codex5hUsedPercent !== null"
|
||||
@@ -303,15 +333,39 @@
|
||||
<div v-else>
|
||||
<!-- Gemini API Key accounts: show quota info -->
|
||||
<AccountQuotaInfo v-if="account.platform === 'gemini'" :account="account" />
|
||||
<!-- API Key accounts with quota limits: show progress bars -->
|
||||
<div v-else-if="hasApiKeyQuota" class="space-y-1">
|
||||
<UsageProgressBar
|
||||
v-if="quotaDailyBar"
|
||||
label="1d"
|
||||
:utilization="quotaDailyBar.utilization"
|
||||
:resets-at="quotaDailyBar.resetsAt"
|
||||
color="indigo"
|
||||
/>
|
||||
<UsageProgressBar
|
||||
v-if="quotaWeeklyBar"
|
||||
label="7d"
|
||||
:utilization="quotaWeeklyBar.utilization"
|
||||
:resets-at="quotaWeeklyBar.resetsAt"
|
||||
color="emerald"
|
||||
/>
|
||||
<UsageProgressBar
|
||||
v-if="quotaTotalBar"
|
||||
label="total"
|
||||
:utilization="quotaTotalBar.utilization"
|
||||
color="purple"
|
||||
/>
|
||||
</div>
|
||||
<div v-else class="text-xs text-gray-400">-</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, AccountUsageInfo, GeminiCredentials, WindowStats } from '@/types'
|
||||
import { buildOpenAIUsageRefreshKey } from '@/utils/accountUsageRefresh'
|
||||
import { resolveCodexUsageWindow } from '@/utils/codexUsage'
|
||||
import UsageProgressBar from './UsageProgressBar.vue'
|
||||
import AccountQuotaInfo from './AccountQuotaInfo.vue'
|
||||
@@ -373,6 +427,36 @@ const hasOpenAIUsageFallback = computed(() => {
|
||||
return !!usageInfo.value?.five_hour || !!usageInfo.value?.seven_day
|
||||
})
|
||||
|
||||
const isActiveOpenAIRateLimited = computed(() => {
|
||||
if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return false
|
||||
if (!props.account.rate_limit_reset_at) return false
|
||||
const resetAt = Date.parse(props.account.rate_limit_reset_at)
|
||||
return !Number.isNaN(resetAt) && resetAt > Date.now()
|
||||
})
|
||||
|
||||
const preferFetchedOpenAIUsage = computed(() => {
|
||||
return (isActiveOpenAIRateLimited.value || isOpenAICodexSnapshotStale.value) && hasOpenAIUsageFallback.value
|
||||
})
|
||||
|
||||
const openAIUsageRefreshKey = computed(() => buildOpenAIUsageRefreshKey(props.account))
|
||||
|
||||
const isOpenAICodexSnapshotStale = computed(() => {
|
||||
if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return false
|
||||
const extra = props.account.extra as Record<string, unknown> | undefined
|
||||
const updatedAtRaw = extra?.codex_usage_updated_at
|
||||
if (!updatedAtRaw) return true
|
||||
const updatedAt = Date.parse(String(updatedAtRaw))
|
||||
if (Number.isNaN(updatedAt)) return true
|
||||
return Date.now() - updatedAt >= 10 * 60 * 1000
|
||||
})
|
||||
|
||||
const shouldAutoLoadUsageOnMount = computed(() => {
|
||||
if (props.account.platform === 'openai' && props.account.type === 'oauth') {
|
||||
return isActiveOpenAIRateLimited.value || !hasCodexUsage.value || isOpenAICodexSnapshotStale.value
|
||||
}
|
||||
return shouldFetchUsage.value
|
||||
})
|
||||
|
||||
const codex5hUsedPercent = computed(() => codex5hWindow.value.usedPercent)
|
||||
const codex5hResetAt = computed(() => codex5hWindow.value.resetAt)
|
||||
const codex7dUsedPercent = computed(() => codex7dWindow.value.usedPercent)
|
||||
@@ -748,7 +832,71 @@ const loadUsage = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// ===== API Key quota progress bars =====
|
||||
|
||||
interface QuotaBarInfo {
|
||||
utilization: number
|
||||
resetsAt: string | null
|
||||
}
|
||||
|
||||
const makeQuotaBar = (
|
||||
used: number,
|
||||
limit: number,
|
||||
startKey?: string
|
||||
): QuotaBarInfo => {
|
||||
const utilization = limit > 0 ? (used / limit) * 100 : 0
|
||||
let resetsAt: string | null = null
|
||||
if (startKey) {
|
||||
const extra = props.account.extra as Record<string, unknown> | undefined
|
||||
const startStr = extra?.[startKey] as string | undefined
|
||||
if (startStr) {
|
||||
const startDate = new Date(startStr)
|
||||
const periodMs = startKey.includes('daily') ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000
|
||||
resetsAt = new Date(startDate.getTime() + periodMs).toISOString()
|
||||
}
|
||||
}
|
||||
return { utilization, resetsAt }
|
||||
}
|
||||
|
||||
const hasApiKeyQuota = computed(() => {
|
||||
if (props.account.type !== 'apikey') return false
|
||||
return (
|
||||
(props.account.quota_daily_limit ?? 0) > 0 ||
|
||||
(props.account.quota_weekly_limit ?? 0) > 0 ||
|
||||
(props.account.quota_limit ?? 0) > 0
|
||||
)
|
||||
})
|
||||
|
||||
const quotaDailyBar = computed((): QuotaBarInfo | null => {
|
||||
const limit = props.account.quota_daily_limit ?? 0
|
||||
if (limit <= 0) return null
|
||||
return makeQuotaBar(props.account.quota_daily_used ?? 0, limit, 'quota_daily_start')
|
||||
})
|
||||
|
||||
const quotaWeeklyBar = computed((): QuotaBarInfo | null => {
|
||||
const limit = props.account.quota_weekly_limit ?? 0
|
||||
if (limit <= 0) return null
|
||||
return makeQuotaBar(props.account.quota_weekly_used ?? 0, limit, 'quota_weekly_start')
|
||||
})
|
||||
|
||||
const quotaTotalBar = computed((): QuotaBarInfo | null => {
|
||||
const limit = props.account.quota_limit ?? 0
|
||||
if (limit <= 0) return null
|
||||
return makeQuotaBar(props.account.quota_used ?? 0, limit)
|
||||
})
|
||||
|
||||
onMounted(() => {
|
||||
if (!shouldAutoLoadUsageOnMount.value) return
|
||||
loadUsage()
|
||||
})
|
||||
|
||||
watch(openAIUsageRefreshKey, (nextKey, prevKey) => {
|
||||
if (!prevKey || nextKey === prevKey) return
|
||||
if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return
|
||||
if (!isActiveOpenAIRateLimited.value && hasCodexUsage.value && !isOpenAICodexSnapshotStale.value) return
|
||||
|
||||
loadUsage().catch((e) => {
|
||||
console.error('Failed to refresh OpenAI usage:', e)
|
||||
})
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -1127,6 +1127,58 @@
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- Pool Mode Section -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.poolModeHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="poolModeEnabled = !poolModeEnabled"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
|
||||
<p class="text-xs text-blue-700 dark:text-blue-400">
|
||||
<Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
|
||||
{{ t('admin.accounts.poolModeInfo') }}
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
|
||||
<input
|
||||
v-model.number="poolModeRetryCount"
|
||||
type="number"
|
||||
min="0"
|
||||
:max="MAX_POOL_MODE_RETRY_COUNT"
|
||||
step="1"
|
||||
class="input"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t('admin.accounts.poolModeRetryCountHint', {
|
||||
default: DEFAULT_POOL_MODE_RETRY_COUNT,
|
||||
max: MAX_POOL_MODE_RETRY_COUNT
|
||||
})
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Custom Error Codes Section -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
@@ -1228,7 +1280,22 @@
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<QuotaLimitCard v-if="form.type === 'apikey'" v-model="editQuotaLimit" />
|
||||
<div v-if="form.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3">
|
||||
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.quotaLimitHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<QuotaLimitCard
|
||||
:totalLimit="editQuotaLimit"
|
||||
:dailyLimit="editQuotaDailyLimit"
|
||||
:weeklyLimit="editQuotaWeeklyLimit"
|
||||
@update:totalLimit="editQuotaLimit = $event"
|
||||
@update:dailyLimit="editQuotaDailyLimit = $event"
|
||||
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI OAuth Model Mapping (OAuth 类型没有 apikey 容器,需要独立的模型映射区域) -->
|
||||
<div
|
||||
@@ -2609,9 +2676,15 @@ const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-
|
||||
const apiKeyBaseUrl = ref('https://api.anthropic.com')
|
||||
const apiKeyValue = ref('')
|
||||
const editQuotaLimit = ref<number | null>(null)
|
||||
const editQuotaDailyLimit = ref<number | null>(null)
|
||||
const editQuotaWeeklyLimit = ref<number | null>(null)
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
const DEFAULT_POOL_MODE_RETRY_COUNT = 3
|
||||
const MAX_POOL_MODE_RETRY_COUNT = 10
|
||||
const poolModeEnabled = ref(false)
|
||||
const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT)
|
||||
const customErrorCodesEnabled = ref(false)
|
||||
const selectedErrorCodes = ref<number[]>([])
|
||||
const customErrorCodeInput = ref<number | null>(null)
|
||||
@@ -3272,6 +3345,8 @@ const resetForm = () => {
|
||||
apiKeyBaseUrl.value = 'https://api.anthropic.com'
|
||||
apiKeyValue.value = ''
|
||||
editQuotaLimit.value = null
|
||||
editQuotaDailyLimit.value = null
|
||||
editQuotaWeeklyLimit.value = null
|
||||
modelMappings.value = []
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = [...claudeModels] // Default fill related models
|
||||
@@ -3281,6 +3356,8 @@ const resetForm = () => {
|
||||
fetchAntigravityDefaultMappings().then(mappings => {
|
||||
antigravityModelMappings.value = [...mappings]
|
||||
})
|
||||
poolModeEnabled.value = false
|
||||
poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT
|
||||
customErrorCodesEnabled.value = false
|
||||
selectedErrorCodes.value = []
|
||||
customErrorCodeInput.value = null
|
||||
@@ -3433,6 +3510,20 @@ const handleMixedChannelCancel = () => {
|
||||
clearMixedChannelDialog()
|
||||
}
|
||||
|
||||
const normalizePoolModeRetryCount = (value: number) => {
|
||||
if (!Number.isFinite(value)) {
|
||||
return DEFAULT_POOL_MODE_RETRY_COUNT
|
||||
}
|
||||
const normalized = Math.trunc(value)
|
||||
if (normalized < 0) {
|
||||
return 0
|
||||
}
|
||||
if (normalized > MAX_POOL_MODE_RETRY_COUNT) {
|
||||
return MAX_POOL_MODE_RETRY_COUNT
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
const handleSubmit = async () => {
|
||||
// For OAuth-based type, handle OAuth flow (goes to step 2)
|
||||
if (isOAuthFlow.value) {
|
||||
@@ -3532,6 +3623,12 @@ const handleSubmit = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// Add pool mode if enabled
|
||||
if (poolModeEnabled.value) {
|
||||
credentials.pool_mode = true
|
||||
credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
|
||||
}
|
||||
|
||||
// Add custom error codes if enabled
|
||||
if (customErrorCodesEnabled.value) {
|
||||
credentials.custom_error_codes_enabled = true
|
||||
@@ -3686,10 +3783,22 @@ const createAccountAndFinish = async (
|
||||
if (!applyTempUnschedConfig(credentials)) {
|
||||
return
|
||||
}
|
||||
// Inject quota_limit for apikey accounts
|
||||
// Inject quota limits for apikey accounts
|
||||
let finalExtra = extra
|
||||
if (type === 'apikey' && editQuotaLimit.value != null && editQuotaLimit.value > 0) {
|
||||
finalExtra = { ...(extra || {}), quota_limit: editQuotaLimit.value }
|
||||
if (type === 'apikey') {
|
||||
const quotaExtra: Record<string, unknown> = { ...(extra || {}) }
|
||||
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
|
||||
quotaExtra.quota_limit = editQuotaLimit.value
|
||||
}
|
||||
if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) {
|
||||
quotaExtra.quota_daily_limit = editQuotaDailyLimit.value
|
||||
}
|
||||
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
|
||||
quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
|
||||
}
|
||||
if (Object.keys(quotaExtra).length > 0) {
|
||||
finalExtra = quotaExtra
|
||||
}
|
||||
}
|
||||
await doCreateAccount({
|
||||
name: form.name,
|
||||
|
||||
@@ -251,6 +251,58 @@
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- Pool Mode Section -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.poolModeHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="poolModeEnabled = !poolModeEnabled"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
|
||||
<p class="text-xs text-blue-700 dark:text-blue-400">
|
||||
<Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
|
||||
{{ t('admin.accounts.poolModeInfo') }}
|
||||
</p>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
|
||||
<input
|
||||
v-model.number="poolModeRetryCount"
|
||||
type="number"
|
||||
min="0"
|
||||
:max="MAX_POOL_MODE_RETRY_COUNT"
|
||||
step="1"
|
||||
class="input"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t('admin.accounts.poolModeRetryCountHint', {
|
||||
default: DEFAULT_POOL_MODE_RETRY_COUNT,
|
||||
max: MAX_POOL_MODE_RETRY_COUNT
|
||||
})
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Custom Error Codes Section -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
@@ -904,7 +956,22 @@
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<QuotaLimitCard v-if="account?.type === 'apikey'" v-model="editQuotaLimit" />
|
||||
<div v-if="account?.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3">
|
||||
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.quotaLimitHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<QuotaLimitCard
|
||||
:totalLimit="editQuotaLimit"
|
||||
:dailyLimit="editQuotaDailyLimit"
|
||||
:weeklyLimit="editQuotaWeeklyLimit"
|
||||
@update:totalLimit="editQuotaLimit = $event"
|
||||
@update:dailyLimit="editQuotaDailyLimit = $event"
|
||||
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI OAuth Codex 官方客户端限制开关 -->
|
||||
<div
|
||||
@@ -1483,6 +1550,10 @@ const editApiKey = ref('')
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
const DEFAULT_POOL_MODE_RETRY_COUNT = 3
|
||||
const MAX_POOL_MODE_RETRY_COUNT = 10
|
||||
const poolModeEnabled = ref(false)
|
||||
const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT)
|
||||
const customErrorCodesEnabled = ref(false)
|
||||
const selectedErrorCodes = ref<number[]>([])
|
||||
const customErrorCodeInput = ref<number | null>(null)
|
||||
@@ -1535,6 +1606,8 @@ const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OF
|
||||
const codexCLIOnlyEnabled = ref(false)
|
||||
const anthropicPassthroughEnabled = ref(false)
|
||||
const editQuotaLimit = ref<number | null>(null)
|
||||
const editQuotaDailyLimit = ref<number | null>(null)
|
||||
const editQuotaWeeklyLimit = ref<number | null>(null)
|
||||
const openAIWSModeOptions = computed(() => [
|
||||
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
||||
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
|
||||
@@ -1641,6 +1714,20 @@ const expiresAtInput = computed({
|
||||
})
|
||||
|
||||
// Watchers
|
||||
const normalizePoolModeRetryCount = (value: number) => {
|
||||
if (!Number.isFinite(value)) {
|
||||
return DEFAULT_POOL_MODE_RETRY_COUNT
|
||||
}
|
||||
const normalized = Math.trunc(value)
|
||||
if (normalized < 0) {
|
||||
return 0
|
||||
}
|
||||
if (normalized > MAX_POOL_MODE_RETRY_COUNT) {
|
||||
return MAX_POOL_MODE_RETRY_COUNT
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
watch(
|
||||
() => props.account,
|
||||
(newAccount) => {
|
||||
@@ -1704,8 +1791,14 @@ watch(
|
||||
if (newAccount.type === 'apikey') {
|
||||
const quotaVal = extra?.quota_limit as number | undefined
|
||||
editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null
|
||||
const dailyVal = extra?.quota_daily_limit as number | undefined
|
||||
editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null
|
||||
const weeklyVal = extra?.quota_weekly_limit as number | undefined
|
||||
editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null
|
||||
} else {
|
||||
editQuotaLimit.value = null
|
||||
editQuotaDailyLimit.value = null
|
||||
editQuotaWeeklyLimit.value = null
|
||||
}
|
||||
|
||||
// Load antigravity model mapping (Antigravity 只支持映射模式)
|
||||
@@ -1782,6 +1875,12 @@ watch(
|
||||
allowedModels.value = []
|
||||
}
|
||||
|
||||
// Load pool mode
|
||||
poolModeEnabled.value = credentials.pool_mode === true
|
||||
poolModeRetryCount.value = normalizePoolModeRetryCount(
|
||||
Number(credentials.pool_mode_retry_count ?? DEFAULT_POOL_MODE_RETRY_COUNT)
|
||||
)
|
||||
|
||||
// Load custom error codes
|
||||
customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true
|
||||
const existingErrorCodes = credentials.custom_error_codes as number[] | undefined
|
||||
@@ -1828,6 +1927,8 @@ watch(
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
}
|
||||
poolModeEnabled.value = false
|
||||
poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT
|
||||
customErrorCodesEnabled.value = false
|
||||
selectedErrorCodes.value = []
|
||||
}
|
||||
@@ -2288,6 +2389,15 @@ const handleSubmit = async () => {
|
||||
newCredentials.model_mapping = currentCredentials.model_mapping
|
||||
}
|
||||
|
||||
// Add pool mode if enabled
|
||||
if (poolModeEnabled.value) {
|
||||
newCredentials.pool_mode = true
|
||||
newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
|
||||
} else {
|
||||
delete newCredentials.pool_mode
|
||||
delete newCredentials.pool_mode_retry_count
|
||||
}
|
||||
|
||||
// Add custom error codes if enabled
|
||||
if (customErrorCodesEnabled.value) {
|
||||
newCredentials.custom_error_codes_enabled = true
|
||||
@@ -2525,6 +2635,16 @@ const handleSubmit = async () => {
|
||||
} else {
|
||||
delete newExtra.quota_limit
|
||||
}
|
||||
if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) {
|
||||
newExtra.quota_daily_limit = editQuotaDailyLimit.value
|
||||
} else {
|
||||
delete newExtra.quota_daily_limit
|
||||
}
|
||||
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
|
||||
newExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
|
||||
} else {
|
||||
delete newExtra.quota_weekly_limit
|
||||
}
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
|
||||
49
frontend/src/components/account/QuotaBadge.vue
Normal file
49
frontend/src/components/account/QuotaBadge.vue
Normal file
@@ -0,0 +1,49 @@
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
const props = defineProps<{
|
||||
used: number
|
||||
limit: number
|
||||
label?: string // 文字前缀,如 "D" / "W";不传时显示 icon
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const badgeClass = computed(() => {
|
||||
if (props.used >= props.limit) {
|
||||
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
|
||||
}
|
||||
if (props.used >= props.limit * 0.8) {
|
||||
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
|
||||
}
|
||||
return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
})
|
||||
|
||||
const tooltip = computed(() => {
|
||||
if (props.used >= props.limit) {
|
||||
return t('admin.accounts.capacity.quota.exceeded')
|
||||
}
|
||||
return t('admin.accounts.capacity.quota.normal')
|
||||
})
|
||||
|
||||
const fmt = (v: number) => v.toFixed(2)
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<span
|
||||
:class="[
|
||||
'inline-flex items-center gap-1 rounded-md px-1.5 py-px text-[10px] font-medium leading-tight',
|
||||
badgeClass
|
||||
]"
|
||||
:title="tooltip"
|
||||
>
|
||||
<span v-if="label" class="font-semibold opacity-70">{{ label }}</span>
|
||||
<svg v-else class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M2.25 18.75a60.07 60.07 0 0115.797 2.101c.727.198 1.453-.342 1.453-1.096V18.75M3.75 4.5v.75A.75.75 0 013 6h-.75m0 0v-.375c0-.621.504-1.125 1.125-1.125H20.25M2.25 6v9m18-10.5v.75c0 .414.336.75.75.75h.75m-1.5-1.5h.375c.621 0 1.125.504 1.125 1.125v9.75c0 .621-.504 1.125-1.125 1.125h-.375m1.5-1.5H21a.75.75 0 00-.75.75v.75m0 0H3.75m0 0h-.375a1.125 1.125 0 01-1.125-1.125V15m1.5 1.5v-.75A.75.75 0 003 15h-.75M15 10.5a3 3 0 11-6 0 3 3 0 016 0zm3 0h.008v.008H18V10.5zm-12 0h.008v.008H6V10.5z" />
|
||||
</svg>
|
||||
<span class="font-mono">${{ fmt(used) }}</span>
|
||||
<span class="text-gray-400 dark:text-gray-500">/</span>
|
||||
<span class="font-mono">${{ fmt(limit) }}</span>
|
||||
</span>
|
||||
</template>
|
||||
@@ -1,50 +1,59 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue'
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
modelValue: number | null
|
||||
totalLimit: number | null
|
||||
dailyLimit: number | null
|
||||
weeklyLimit: number | null
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:modelValue': [value: number | null]
|
||||
'update:totalLimit': [value: number | null]
|
||||
'update:dailyLimit': [value: number | null]
|
||||
'update:weeklyLimit': [value: number | null]
|
||||
}>()
|
||||
|
||||
const enabled = ref(props.modelValue != null && props.modelValue > 0)
|
||||
|
||||
// Sync enabled state when modelValue changes externally (e.g. account load)
|
||||
watch(
|
||||
() => props.modelValue,
|
||||
(val) => {
|
||||
enabled.value = val != null && val > 0
|
||||
}
|
||||
const enabled = computed(() =>
|
||||
(props.totalLimit != null && props.totalLimit > 0) ||
|
||||
(props.dailyLimit != null && props.dailyLimit > 0) ||
|
||||
(props.weeklyLimit != null && props.weeklyLimit > 0)
|
||||
)
|
||||
|
||||
// When toggle is turned off, clear the value
|
||||
const localEnabled = ref(enabled.value)
|
||||
|
||||
// Sync when props change externally
|
||||
watch(enabled, (val) => {
|
||||
localEnabled.value = val
|
||||
})
|
||||
|
||||
// When toggle is turned off, clear all values
|
||||
watch(localEnabled, (val) => {
|
||||
if (!val) {
|
||||
emit('update:modelValue', null)
|
||||
emit('update:totalLimit', null)
|
||||
emit('update:dailyLimit', null)
|
||||
emit('update:weeklyLimit', null)
|
||||
}
|
||||
})
|
||||
|
||||
const onInput = (e: Event) => {
|
||||
const onTotalInput = (e: Event) => {
|
||||
const raw = (e.target as HTMLInputElement).valueAsNumber
|
||||
emit('update:modelValue', Number.isNaN(raw) ? null : raw)
|
||||
emit('update:totalLimit', Number.isNaN(raw) ? null : raw)
|
||||
}
|
||||
const onDailyInput = (e: Event) => {
|
||||
const raw = (e.target as HTMLInputElement).valueAsNumber
|
||||
emit('update:dailyLimit', Number.isNaN(raw) ? null : raw)
|
||||
}
|
||||
const onWeeklyInput = (e: Event) => {
|
||||
const raw = (e.target as HTMLInputElement).valueAsNumber
|
||||
emit('update:weeklyLimit', Number.isNaN(raw) ? null : raw)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3">
|
||||
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.quotaLimitHint') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
|
||||
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.quotaLimitToggle') }}</label>
|
||||
@@ -54,29 +63,30 @@ const onInput = (e: Event) => {
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="enabled = !enabled"
|
||||
@click="localEnabled = !localEnabled"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
enabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
localEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
enabled ? 'translate-x-5' : 'translate-x-0'
|
||||
localEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div v-if="enabled" class="space-y-3">
|
||||
<div v-if="localEnabled" class="space-y-3">
|
||||
<!-- 日配额 -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.quotaLimitAmount') }}</label>
|
||||
<label class="input-label">{{ t('admin.accounts.quotaDailyLimit') }}</label>
|
||||
<div class="relative">
|
||||
<span class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500 dark:text-gray-400">$</span>
|
||||
<input
|
||||
:value="modelValue"
|
||||
@input="onInput"
|
||||
:value="dailyLimit"
|
||||
@input="onDailyInput"
|
||||
type="number"
|
||||
min="0"
|
||||
step="0.01"
|
||||
@@ -84,9 +94,44 @@ const onInput = (e: Event) => {
|
||||
:placeholder="t('admin.accounts.quotaLimitPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<p class="input-hint">{{ t('admin.accounts.quotaLimitAmountHint') }}</p>
|
||||
<p class="input-hint">{{ t('admin.accounts.quotaDailyLimitHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- 周配额 -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.quotaWeeklyLimit') }}</label>
|
||||
<div class="relative">
|
||||
<span class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500 dark:text-gray-400">$</span>
|
||||
<input
|
||||
:value="weeklyLimit"
|
||||
@input="onWeeklyInput"
|
||||
type="number"
|
||||
min="0"
|
||||
step="0.01"
|
||||
class="input pl-7"
|
||||
:placeholder="t('admin.accounts.quotaLimitPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<p class="input-hint">{{ t('admin.accounts.quotaWeeklyLimitHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- 总配额 -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.quotaTotalLimit') }}</label>
|
||||
<div class="relative">
|
||||
<span class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500 dark:text-gray-400">$</span>
|
||||
<input
|
||||
:value="totalLimit"
|
||||
@input="onTotalInput"
|
||||
type="number"
|
||||
min="0"
|
||||
step="0.01"
|
||||
class="input pl-7"
|
||||
:placeholder="t('admin.accounts.quotaLimitPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<p class="input-hint">{{ t('admin.accounts.quotaTotalLimitHint') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -29,6 +29,10 @@
|
||||
</div>
|
||||
|
||||
<div v-else class="space-y-4">
|
||||
<div class="rounded-lg border border-emerald-200 bg-emerald-50 p-3 text-sm text-emerald-800 dark:border-emerald-500/30 dark:bg-emerald-500/10 dark:text-emerald-300">
|
||||
{{ t('admin.accounts.recoverStateHint') }}
|
||||
</div>
|
||||
|
||||
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.tempUnschedulable.accountName') }}
|
||||
@@ -131,7 +135,7 @@
|
||||
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||
></path>
|
||||
</svg>
|
||||
{{ t('admin.accounts.tempUnschedulable.reset') }}
|
||||
{{ t('admin.accounts.recoverState') }}
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
@@ -154,7 +158,7 @@ const props = defineProps<{
|
||||
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
reset: []
|
||||
reset: [account: Account]
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
@@ -225,12 +229,12 @@ const handleReset = async () => {
|
||||
if (!props.account) return
|
||||
resetting.value = true
|
||||
try {
|
||||
await adminAPI.accounts.resetTempUnschedulable(props.account.id)
|
||||
appStore.showSuccess(t('admin.accounts.tempUnschedulable.resetSuccess'))
|
||||
emit('reset')
|
||||
const updated = await adminAPI.accounts.recoverState(props.account.id)
|
||||
appStore.showSuccess(t('admin.accounts.recoverStateSuccess'))
|
||||
emit('reset', updated)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
appStore.showError(error?.message || t('admin.accounts.tempUnschedulable.resetFailed'))
|
||||
appStore.showError(error?.message || t('admin.accounts.recoverStateFailed'))
|
||||
} finally {
|
||||
resetting.value = false
|
||||
}
|
||||
|
||||
@@ -1,30 +1,5 @@
|
||||
<template>
|
||||
<div>
|
||||
<!-- Window stats row (above progress bar, left-right aligned with progress bar) -->
|
||||
<div
|
||||
v-if="windowStats"
|
||||
class="mb-0.5 flex items-center justify-between"
|
||||
:title="statsTitle || t('admin.accounts.usageWindow.statsTitle')"
|
||||
>
|
||||
<div
|
||||
class="flex cursor-help items-center gap-1.5 text-[9px] text-gray-500 dark:text-gray-400"
|
||||
>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
{{ formatRequests }} req
|
||||
</span>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
{{ formatTokens }}
|
||||
</span>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800"> A ${{ formatAccountCost }} </span>
|
||||
<span
|
||||
v-if="windowStats?.user_cost != null"
|
||||
class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800"
|
||||
>
|
||||
U ${{ formatUserCost }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Progress bar row -->
|
||||
<div class="flex items-center gap-1">
|
||||
<!-- Label badge (fixed width for alignment) -->
|
||||
@@ -57,7 +32,6 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import type { WindowStats } from '@/types'
|
||||
|
||||
const props = defineProps<{
|
||||
@@ -66,11 +40,8 @@ const props = defineProps<{
|
||||
resetsAt?: string | null
|
||||
color: 'indigo' | 'emerald' | 'purple' | 'amber'
|
||||
windowStats?: WindowStats | null
|
||||
statsTitle?: string
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
// Label background colors
|
||||
const labelClass = computed(() => {
|
||||
const colors = {
|
||||
@@ -117,12 +88,12 @@ const displayPercent = computed(() => {
|
||||
|
||||
// Format reset time
|
||||
const formatResetTime = computed(() => {
|
||||
if (!props.resetsAt) return t('common.notAvailable')
|
||||
if (!props.resetsAt) return '-'
|
||||
const date = new Date(props.resetsAt)
|
||||
const now = new Date()
|
||||
const diffMs = date.getTime() - now.getTime()
|
||||
|
||||
if (diffMs <= 0) return t('common.now')
|
||||
if (diffMs <= 0) return '现在'
|
||||
|
||||
const diffHours = Math.floor(diffMs / (1000 * 60 * 60))
|
||||
const diffMins = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60))
|
||||
@@ -137,31 +108,4 @@ const formatResetTime = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
// Format window stats
|
||||
const formatRequests = computed(() => {
|
||||
if (!props.windowStats) return ''
|
||||
const r = props.windowStats.requests
|
||||
if (r >= 1000000) return `${(r / 1000000).toFixed(1)}M`
|
||||
if (r >= 1000) return `${(r / 1000).toFixed(1)}K`
|
||||
return r.toString()
|
||||
})
|
||||
|
||||
const formatTokens = computed(() => {
|
||||
if (!props.windowStats) return ''
|
||||
const t = props.windowStats.tokens
|
||||
if (t >= 1000000000) return `${(t / 1000000000).toFixed(1)}B`
|
||||
if (t >= 1000000) return `${(t / 1000000).toFixed(1)}M`
|
||||
if (t >= 1000) return `${(t / 1000).toFixed(1)}K`
|
||||
return t.toString()
|
||||
})
|
||||
|
||||
const formatAccountCost = computed(() => {
|
||||
if (!props.windowStats) return '0.00'
|
||||
return props.windowStats.cost.toFixed(2)
|
||||
})
|
||||
|
||||
const formatUserCost = computed(() => {
|
||||
if (!props.windowStats || props.windowStats.user_cost == null) return '0.00'
|
||||
return props.windowStats.user_cost.toFixed(2)
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -68,6 +68,102 @@ describe('AccountUsageCell', () => {
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.gemini3Image|70|2026-03-01T09:00:00Z')
|
||||
})
|
||||
|
||||
|
||||
it('OpenAI OAuth 快照已过期时首屏会重新请求 usage', async () => {
|
||||
getUsage.mockResolvedValue({
|
||||
five_hour: {
|
||||
utilization: 15,
|
||||
resets_at: '2026-03-08T12:00:00Z',
|
||||
remaining_seconds: 3600,
|
||||
window_stats: {
|
||||
requests: 3,
|
||||
tokens: 300,
|
||||
cost: 0.03,
|
||||
standard_cost: 0.03,
|
||||
user_cost: 0.03
|
||||
}
|
||||
},
|
||||
seven_day: {
|
||||
utilization: 77,
|
||||
resets_at: '2026-03-13T12:00:00Z',
|
||||
remaining_seconds: 3600,
|
||||
window_stats: {
|
||||
requests: 3,
|
||||
tokens: 300,
|
||||
cost: 0.03,
|
||||
standard_cost: 0.03,
|
||||
user_cost: 0.03
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
id: 2000,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
extra: {
|
||||
codex_usage_updated_at: '2026-03-07T00:00:00Z',
|
||||
codex_5h_used_percent: 12,
|
||||
codex_5h_reset_at: '2026-03-08T12:00:00Z',
|
||||
codex_7d_used_percent: 34,
|
||||
codex_7d_reset_at: '2026-03-13T12:00:00Z'
|
||||
}
|
||||
} as any
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: {
|
||||
props: ['label', 'utilization', 'resetsAt', 'windowStats', 'color'],
|
||||
template: '<div class="usage-bar">{{ label }}|{{ utilization }}|{{ windowStats?.tokens }}</div>'
|
||||
},
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(getUsage).toHaveBeenCalledWith(2000)
|
||||
expect(wrapper.text()).toContain('5h|15|300')
|
||||
expect(wrapper.text()).toContain('7d|77|300')
|
||||
})
|
||||
|
||||
it('OpenAI OAuth 有现成快照且未限额时不会首屏请求 usage', async () => {
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
id: 2001,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
extra: {
|
||||
codex_usage_updated_at: '2099-03-07T10:00:00Z',
|
||||
codex_5h_used_percent: 12,
|
||||
codex_5h_reset_at: '2099-03-07T12:00:00Z',
|
||||
codex_7d_used_percent: 34,
|
||||
codex_7d_reset_at: '2099-03-13T12:00:00Z'
|
||||
}
|
||||
} as any
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: {
|
||||
props: ['label', 'utilization', 'resetsAt', 'windowStats', 'color'],
|
||||
template: '<div class="usage-bar">{{ label }}|{{ utilization }}</div>'
|
||||
},
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(getUsage).not.toHaveBeenCalled()
|
||||
expect(wrapper.text()).toContain('5h|12')
|
||||
expect(wrapper.text()).toContain('7d|34')
|
||||
})
|
||||
|
||||
it('OpenAI OAuth 在无 codex 快照时会回退显示 usage 接口窗口', async () => {
|
||||
getUsage.mockResolvedValue({
|
||||
five_hour: {
|
||||
@@ -122,4 +218,137 @@ describe('AccountUsageCell', () => {
|
||||
expect(wrapper.text()).toContain('5h|0|27700')
|
||||
expect(wrapper.text()).toContain('7d|0|27700')
|
||||
})
|
||||
|
||||
it('OpenAI OAuth 在行数据刷新但仍无 codex 快照时会重新拉取 usage', async () => {
|
||||
getUsage
|
||||
.mockResolvedValueOnce({
|
||||
five_hour: {
|
||||
utilization: 0,
|
||||
resets_at: null,
|
||||
remaining_seconds: 0,
|
||||
window_stats: {
|
||||
requests: 1,
|
||||
tokens: 100,
|
||||
cost: 0.01,
|
||||
standard_cost: 0.01,
|
||||
user_cost: 0.01
|
||||
}
|
||||
},
|
||||
seven_day: null
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
five_hour: {
|
||||
utilization: 0,
|
||||
resets_at: null,
|
||||
remaining_seconds: 0,
|
||||
window_stats: {
|
||||
requests: 2,
|
||||
tokens: 200,
|
||||
cost: 0.02,
|
||||
standard_cost: 0.02,
|
||||
user_cost: 0.02
|
||||
}
|
||||
},
|
||||
seven_day: null
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
id: 2003,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
updated_at: '2026-03-07T10:00:00Z',
|
||||
extra: {}
|
||||
} as any
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: {
|
||||
props: ['label', 'utilization', 'resetsAt', 'windowStats', 'color'],
|
||||
template: '<div class="usage-bar">{{ label }}|{{ utilization }}|{{ windowStats?.tokens }}</div>'
|
||||
},
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(wrapper.text()).toContain('5h|0|100')
|
||||
expect(getUsage).toHaveBeenCalledTimes(1)
|
||||
|
||||
await wrapper.setProps({
|
||||
account: {
|
||||
id: 2003,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
updated_at: '2026-03-07T10:01:00Z',
|
||||
extra: {}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
expect(getUsage).toHaveBeenCalledTimes(2)
|
||||
expect(wrapper.text()).toContain('5h|0|200')
|
||||
})
|
||||
|
||||
it('OpenAI OAuth 已限额时首屏优先展示重新查询后的 usage,而不是旧 codex 快照', async () => {
|
||||
getUsage.mockResolvedValue({
|
||||
five_hour: {
|
||||
utilization: 100,
|
||||
resets_at: '2026-03-07T12:00:00Z',
|
||||
remaining_seconds: 3600,
|
||||
window_stats: {
|
||||
requests: 211,
|
||||
tokens: 106540000,
|
||||
cost: 38.13,
|
||||
standard_cost: 38.13,
|
||||
user_cost: 38.13
|
||||
}
|
||||
},
|
||||
seven_day: {
|
||||
utilization: 100,
|
||||
resets_at: '2026-03-13T12:00:00Z',
|
||||
remaining_seconds: 3600,
|
||||
window_stats: {
|
||||
requests: 211,
|
||||
tokens: 106540000,
|
||||
cost: 38.13,
|
||||
standard_cost: 38.13,
|
||||
user_cost: 38.13
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
id: 2004,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
rate_limit_reset_at: '2099-03-07T12:00:00Z',
|
||||
extra: {
|
||||
codex_5h_used_percent: 0,
|
||||
codex_7d_used_percent: 0
|
||||
}
|
||||
} as any
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: {
|
||||
props: ['label', 'utilization', 'resetsAt', 'windowStats', 'color'],
|
||||
template: '<div class="usage-bar">{{ label }}|{{ utilization }}|{{ windowStats?.tokens }}</div>'
|
||||
},
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(getUsage).toHaveBeenCalledWith(2004)
|
||||
expect(wrapper.text()).toContain('5h|100|106540000')
|
||||
expect(wrapper.text()).toContain('7d|100|106540000')
|
||||
expect(wrapper.text()).not.toContain('5h|0|')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -32,14 +32,10 @@
|
||||
{{ t('admin.accounts.refreshToken') }}
|
||||
</button>
|
||||
</template>
|
||||
<div v-if="account.status === 'error' || isRateLimited || isOverloaded" class="my-1 border-t border-gray-100 dark:border-dark-700"></div>
|
||||
<button v-if="account.status === 'error'" @click="$emit('reset-status', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-yellow-600 hover:bg-gray-100 dark:hover:bg-dark-700">
|
||||
<div v-if="hasRecoverableState" class="my-1 border-t border-gray-100 dark:border-dark-700"></div>
|
||||
<button v-if="hasRecoverableState" @click="$emit('recover-state', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-emerald-600 hover:bg-gray-100 dark:hover:bg-dark-700">
|
||||
<Icon name="sync" size="sm" />
|
||||
{{ t('admin.accounts.resetStatus') }}
|
||||
</button>
|
||||
<button v-if="isRateLimited || isOverloaded" @click="$emit('clear-rate-limit', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-amber-600 hover:bg-gray-100 dark:hover:bg-dark-700">
|
||||
<Icon name="clock" size="sm" />
|
||||
{{ t('admin.accounts.clearRateLimit') }}
|
||||
{{ t('admin.accounts.recoverState') }}
|
||||
</button>
|
||||
<button v-if="hasQuotaLimit" @click="$emit('reset-quota', account); $emit('close')" class="flex w-full items-center gap-2 px-4 py-2 text-sm text-teal-600 hover:bg-gray-100 dark:hover:bg-dark-700">
|
||||
<Icon name="refresh" size="sm" />
|
||||
@@ -59,7 +55,7 @@ import { Icon } from '@/components/icons'
|
||||
import type { Account } from '@/types'
|
||||
|
||||
const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>()
|
||||
const emit = defineEmits(['close', 'test', 'stats', 'schedule', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit', 'reset-quota'])
|
||||
const emit = defineEmits(['close', 'test', 'stats', 'schedule', 'reauth', 'refresh-token', 'recover-state', 'reset-quota'])
|
||||
const { t } = useI18n()
|
||||
const isRateLimited = computed(() => {
|
||||
if (props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) {
|
||||
@@ -75,11 +71,16 @@ const isRateLimited = computed(() => {
|
||||
return false
|
||||
})
|
||||
const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date())
|
||||
const isTempUnschedulable = computed(() => props.account?.temp_unschedulable_until && new Date(props.account.temp_unschedulable_until) > new Date())
|
||||
const hasRecoverableState = computed(() => {
|
||||
return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value)
|
||||
})
|
||||
const hasQuotaLimit = computed(() => {
|
||||
return props.account?.type === 'apikey' &&
|
||||
props.account?.quota_limit !== undefined &&
|
||||
props.account?.quota_limit !== null &&
|
||||
props.account?.quota_limit > 0
|
||||
return props.account?.type === 'apikey' && (
|
||||
(props.account?.quota_limit ?? 0) > 0 ||
|
||||
(props.account?.quota_daily_limit ?? 0) > 0 ||
|
||||
(props.account?.quota_weekly_limit ?? 0) > 0
|
||||
)
|
||||
})
|
||||
|
||||
const handleKeydown = (event: KeyboardEvent) => {
|
||||
|
||||
@@ -41,8 +41,24 @@
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
<label class="mb-1 flex items-center gap-1 text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.cronExpression') }}
|
||||
<HelpTooltip>
|
||||
<template #trigger>
|
||||
<span class="inline-flex h-4 w-4 cursor-help items-center justify-center rounded-full border border-gray-400/70 text-[10px] font-semibold text-gray-400 transition-colors hover:border-primary-500 hover:text-primary-600 dark:border-gray-500 dark:text-gray-500 dark:hover:border-primary-400 dark:hover:text-primary-400">
|
||||
?
|
||||
</span>
|
||||
</template>
|
||||
<div class="space-y-1.5">
|
||||
<p class="font-medium">{{ t('admin.scheduledTests.cronTooltipTitle') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipMeaning') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipExampleEvery30Min') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipExampleHourly') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipExampleDaily') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipExampleWeekly') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipRange') }}</p>
|
||||
</div>
|
||||
</HelpTooltip>
|
||||
</label>
|
||||
<Input
|
||||
v-model="newPlan.cron_expression"
|
||||
@@ -51,8 +67,22 @@
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
<label class="mb-1 flex items-center gap-1 text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.maxResults') }}
|
||||
<HelpTooltip>
|
||||
<template #trigger>
|
||||
<span class="inline-flex h-4 w-4 cursor-help items-center justify-center rounded-full border border-gray-400/70 text-[10px] font-semibold text-gray-400 transition-colors hover:border-primary-500 hover:text-primary-600 dark:border-gray-500 dark:text-gray-500 dark:hover:border-primary-400 dark:hover:text-primary-400">
|
||||
?
|
||||
</span>
|
||||
</template>
|
||||
<div class="space-y-1.5">
|
||||
<p class="font-medium">{{ t('admin.scheduledTests.maxResultsTooltipTitle') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.maxResultsTooltipMeaning') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.maxResultsTooltipBody') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.maxResultsTooltipExample') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.maxResultsTooltipRange') }}</p>
|
||||
</div>
|
||||
</HelpTooltip>
|
||||
</label>
|
||||
<Input
|
||||
v-model="newPlan.max_results"
|
||||
@@ -66,6 +96,17 @@
|
||||
{{ t('admin.scheduledTests.enabled') }}
|
||||
</label>
|
||||
</div>
|
||||
<div class="flex items-end">
|
||||
<div>
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
|
||||
<Toggle v-model="newPlan.auto_recover" />
|
||||
{{ t('admin.scheduledTests.autoRecover') }}
|
||||
</label>
|
||||
<p class="mt-0.5 text-xs text-gray-400 dark:text-gray-500">
|
||||
{{ t('admin.scheduledTests.autoRecoverHelp') }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mt-3 flex justify-end gap-2">
|
||||
<button
|
||||
@@ -135,6 +176,14 @@
|
||||
{{ plan.enabled ? t('admin.scheduledTests.enabled') : '' }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Auto Recover Badge -->
|
||||
<span
|
||||
v-if="plan.auto_recover"
|
||||
class="inline-flex items-center rounded-full bg-emerald-100 px-2 py-0.5 text-xs font-medium text-emerald-700 dark:bg-emerald-500/20 dark:text-emerald-400"
|
||||
>
|
||||
{{ t('admin.scheduledTests.autoRecover') }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center gap-3">
|
||||
@@ -202,8 +251,24 @@
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
<label class="mb-1 flex items-center gap-1 text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.cronExpression') }}
|
||||
<HelpTooltip>
|
||||
<template #trigger>
|
||||
<span class="inline-flex h-4 w-4 cursor-help items-center justify-center rounded-full border border-gray-400/70 text-[10px] font-semibold text-gray-400 transition-colors hover:border-primary-500 hover:text-primary-600 dark:border-gray-500 dark:text-gray-500 dark:hover:border-primary-400 dark:hover:text-primary-400">
|
||||
?
|
||||
</span>
|
||||
</template>
|
||||
<div class="space-y-1.5">
|
||||
<p class="font-medium">{{ t('admin.scheduledTests.cronTooltipTitle') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipMeaning') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipExampleEvery30Min') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipExampleHourly') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipExampleDaily') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipExampleWeekly') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.cronTooltipRange') }}</p>
|
||||
</div>
|
||||
</HelpTooltip>
|
||||
</label>
|
||||
<Input
|
||||
v-model="editForm.cron_expression"
|
||||
@@ -212,8 +277,22 @@
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
<label class="mb-1 flex items-center gap-1 text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.scheduledTests.maxResults') }}
|
||||
<HelpTooltip>
|
||||
<template #trigger>
|
||||
<span class="inline-flex h-4 w-4 cursor-help items-center justify-center rounded-full border border-gray-400/70 text-[10px] font-semibold text-gray-400 transition-colors hover:border-primary-500 hover:text-primary-600 dark:border-gray-500 dark:text-gray-500 dark:hover:border-primary-400 dark:hover:text-primary-400">
|
||||
?
|
||||
</span>
|
||||
</template>
|
||||
<div class="space-y-1.5">
|
||||
<p class="font-medium">{{ t('admin.scheduledTests.maxResultsTooltipTitle') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.maxResultsTooltipMeaning') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.maxResultsTooltipBody') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.maxResultsTooltipExample') }}</p>
|
||||
<p>{{ t('admin.scheduledTests.maxResultsTooltipRange') }}</p>
|
||||
</div>
|
||||
</HelpTooltip>
|
||||
</label>
|
||||
<Input
|
||||
v-model="editForm.max_results"
|
||||
@@ -227,6 +306,17 @@
|
||||
{{ t('admin.scheduledTests.enabled') }}
|
||||
</label>
|
||||
</div>
|
||||
<div class="flex items-end">
|
||||
<div>
|
||||
<label class="flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300">
|
||||
<Toggle v-model="editForm.auto_recover" />
|
||||
{{ t('admin.scheduledTests.autoRecover') }}
|
||||
</label>
|
||||
<p class="mt-0.5 text-xs text-gray-400 dark:text-gray-500">
|
||||
{{ t('admin.scheduledTests.autoRecoverHelp') }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mt-3 flex justify-end gap-2">
|
||||
<button
|
||||
@@ -377,6 +467,7 @@ import { ref, reactive, watch } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import HelpTooltip from '@/components/common/HelpTooltip.vue'
|
||||
import Select, { type SelectOption } from '@/components/common/Select.vue'
|
||||
import Input from '@/components/common/Input.vue'
|
||||
import Toggle from '@/components/common/Toggle.vue'
|
||||
@@ -416,14 +507,16 @@ const editForm = reactive({
|
||||
model_id: '' as string,
|
||||
cron_expression: '' as string,
|
||||
max_results: '100' as string,
|
||||
enabled: true
|
||||
enabled: true,
|
||||
auto_recover: false
|
||||
})
|
||||
|
||||
const newPlan = reactive({
|
||||
model_id: '' as string,
|
||||
cron_expression: '' as string,
|
||||
max_results: '100' as string,
|
||||
enabled: true
|
||||
enabled: true,
|
||||
auto_recover: false
|
||||
})
|
||||
|
||||
const resetNewPlan = () => {
|
||||
@@ -431,6 +524,7 @@ const resetNewPlan = () => {
|
||||
newPlan.cron_expression = ''
|
||||
newPlan.max_results = '100'
|
||||
newPlan.enabled = true
|
||||
newPlan.auto_recover = false
|
||||
}
|
||||
|
||||
// Load plans when dialog opens
|
||||
@@ -472,7 +566,8 @@ const handleCreate = async () => {
|
||||
model_id: newPlan.model_id,
|
||||
cron_expression: newPlan.cron_expression,
|
||||
enabled: newPlan.enabled,
|
||||
max_results: maxResults
|
||||
max_results: maxResults,
|
||||
auto_recover: newPlan.auto_recover
|
||||
})
|
||||
appStore.showSuccess(t('admin.scheduledTests.createSuccess'))
|
||||
showAddForm.value = false
|
||||
@@ -504,6 +599,7 @@ const startEdit = (plan: ScheduledTestPlan) => {
|
||||
editForm.cron_expression = plan.cron_expression
|
||||
editForm.max_results = String(plan.max_results)
|
||||
editForm.enabled = plan.enabled
|
||||
editForm.auto_recover = plan.auto_recover
|
||||
}
|
||||
|
||||
const cancelEdit = () => {
|
||||
@@ -518,7 +614,8 @@ const handleEdit = async () => {
|
||||
model_id: editForm.model_id,
|
||||
cron_expression: editForm.cron_expression,
|
||||
max_results: Number(editForm.max_results) || 100,
|
||||
enabled: editForm.enabled
|
||||
enabled: editForm.enabled,
|
||||
auto_recover: editForm.auto_recover
|
||||
})
|
||||
const index = plans.value.findIndex((p) => p.id === editingPlanId.value)
|
||||
if (index !== -1) {
|
||||
|
||||
@@ -4,7 +4,15 @@
|
||||
<DataTable :columns="columns" :data="data" :loading="loading">
|
||||
<template #cell-user="{ row }">
|
||||
<div class="text-sm">
|
||||
<span class="font-medium text-gray-900 dark:text-white">{{ row.user?.email || '-' }}</span>
|
||||
<button
|
||||
v-if="row.user?.email"
|
||||
class="font-medium text-primary-600 underline decoration-dashed underline-offset-2 transition-colors hover:text-primary-700 dark:text-primary-400 dark:hover:text-primary-300"
|
||||
@click="$emit('userClick', row.user_id, row.user?.email)"
|
||||
:title="t('admin.usage.clickToViewBalance')"
|
||||
>
|
||||
{{ row.user.email }}
|
||||
</button>
|
||||
<span v-else class="font-medium text-gray-900 dark:text-white">-</span>
|
||||
<span class="ml-1 text-gray-500 dark:text-gray-400">#{{ row.user_id }}</span>
|
||||
</div>
|
||||
</template>
|
||||
@@ -228,6 +236,14 @@
|
||||
<span class="text-gray-400">{{ t('admin.usage.outputCost') }}</span>
|
||||
<span class="font-medium text-white">${{ tooltipData.output_cost.toFixed(6) }}</span>
|
||||
</div>
|
||||
<div v-if="tooltipData && tooltipData.input_tokens > 0" class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('usage.inputTokenPrice') }}</span>
|
||||
<span class="font-medium text-sky-300">{{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}</span>
|
||||
</div>
|
||||
<div v-if="tooltipData && tooltipData.output_tokens > 0" class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('usage.outputTokenPrice') }}</span>
|
||||
<span class="font-medium text-violet-300">{{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}</span>
|
||||
</div>
|
||||
<div v-if="tooltipData && tooltipData.cache_creation_cost > 0" class="flex items-center justify-between gap-4">
|
||||
<span class="text-gray-400">{{ t('admin.usage.cacheCreationCost') }}</span>
|
||||
<span class="font-medium text-white">${{ tooltipData.cache_creation_cost.toFixed(6) }}</span>
|
||||
@@ -238,6 +254,10 @@
|
||||
</div>
|
||||
</div>
|
||||
<!-- Rate and Summary -->
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<span class="text-gray-400">{{ t('usage.serviceTier') }}</span>
|
||||
<span class="font-semibold text-cyan-300">{{ getUsageServiceTierLabel(tooltipData?.service_tier, t) }}</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<span class="text-gray-400">{{ t('usage.rate') }}</span>
|
||||
<span class="font-semibold text-blue-400">{{ (tooltipData?.rate_multiplier || 1).toFixed(2) }}x</span>
|
||||
@@ -271,6 +291,8 @@
|
||||
import { ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { formatDateTime, formatReasoningEffort } from '@/utils/format'
|
||||
import { formatTokenPricePerMillion } from '@/utils/usagePricing'
|
||||
import { getUsageServiceTierLabel } from '@/utils/usageServiceTier'
|
||||
import { resolveUsageRequestType } from '@/utils/usageRequestType'
|
||||
import DataTable from '@/components/common/DataTable.vue'
|
||||
import EmptyState from '@/components/common/EmptyState.vue'
|
||||
@@ -278,6 +300,7 @@ import Icon from '@/components/icons/Icon.vue'
|
||||
import type { AdminUsageLog } from '@/types'
|
||||
|
||||
defineProps(['data', 'loading', 'columns'])
|
||||
defineEmits(['userClick'])
|
||||
const { t } = useI18n()
|
||||
|
||||
// Tooltip state - cost
|
||||
|
||||
111
frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts
Normal file
111
frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts
Normal file
@@ -0,0 +1,111 @@
|
||||
import { describe, expect, it, vi, beforeEach } from 'vitest'
|
||||
import { mount } from '@vue/test-utils'
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import UsageTable from '../UsageTable.vue'
|
||||
|
||||
const messages: Record<string, string> = {
|
||||
'usage.costDetails': 'Cost Breakdown',
|
||||
'admin.usage.inputCost': 'Input Cost',
|
||||
'admin.usage.outputCost': 'Output Cost',
|
||||
'admin.usage.cacheCreationCost': 'Cache Creation Cost',
|
||||
'admin.usage.cacheReadCost': 'Cache Read Cost',
|
||||
'usage.inputTokenPrice': 'Input price',
|
||||
'usage.outputTokenPrice': 'Output price',
|
||||
'usage.perMillionTokens': '/ 1M tokens',
|
||||
'usage.serviceTier': 'Service tier',
|
||||
'usage.serviceTierPriority': 'Fast',
|
||||
'usage.serviceTierFlex': 'Flex',
|
||||
'usage.serviceTierStandard': 'Standard',
|
||||
'usage.rate': 'Rate',
|
||||
'usage.accountMultiplier': 'Account rate',
|
||||
'usage.original': 'Original',
|
||||
'usage.userBilled': 'User billed',
|
||||
'usage.accountBilled': 'Account billed',
|
||||
}
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||
return {
|
||||
...actual,
|
||||
useI18n: () => ({
|
||||
t: (key: string) => messages[key] ?? key,
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
const DataTableStub = {
|
||||
props: ['data'],
|
||||
template: `
|
||||
<div>
|
||||
<div v-for="row in data" :key="row.request_id">
|
||||
<slot name="cell-cost" :row="row" />
|
||||
</div>
|
||||
</div>
|
||||
`,
|
||||
}
|
||||
|
||||
describe('admin UsageTable tooltip', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(HTMLElement.prototype, 'getBoundingClientRect').mockReturnValue({
|
||||
x: 0,
|
||||
y: 0,
|
||||
top: 20,
|
||||
left: 20,
|
||||
right: 120,
|
||||
bottom: 40,
|
||||
width: 100,
|
||||
height: 20,
|
||||
toJSON: () => ({}),
|
||||
} as DOMRect)
|
||||
})
|
||||
|
||||
it('shows service tier and billing breakdown in cost tooltip', async () => {
|
||||
const row = {
|
||||
request_id: 'req-admin-1',
|
||||
actual_cost: 0.092883,
|
||||
total_cost: 0.092883,
|
||||
account_rate_multiplier: 1,
|
||||
rate_multiplier: 1,
|
||||
service_tier: 'priority',
|
||||
input_cost: 0.020285,
|
||||
output_cost: 0.00303,
|
||||
cache_creation_cost: 0,
|
||||
cache_read_cost: 0.069568,
|
||||
input_tokens: 4057,
|
||||
output_tokens: 101,
|
||||
}
|
||||
|
||||
const wrapper = mount(UsageTable, {
|
||||
props: {
|
||||
data: [row],
|
||||
loading: false,
|
||||
columns: [],
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
DataTable: DataTableStub,
|
||||
EmptyState: true,
|
||||
Icon: true,
|
||||
Teleport: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
await wrapper.find('.group.relative').trigger('mouseenter')
|
||||
await nextTick()
|
||||
|
||||
const text = wrapper.text()
|
||||
expect(text).toContain('Service tier')
|
||||
expect(text).toContain('Fast')
|
||||
expect(text).toContain('Rate')
|
||||
expect(text).toContain('1.00x')
|
||||
expect(text).toContain('Account rate')
|
||||
expect(text).toContain('User billed')
|
||||
expect(text).toContain('Account billed')
|
||||
expect(text).toContain('$0.092883')
|
||||
expect(text).toContain('$5.0000 / 1M tokens')
|
||||
expect(text).toContain('$30.0000 / 1M tokens')
|
||||
expect(text).toContain('$0.069568')
|
||||
})
|
||||
})
|
||||
@@ -54,6 +54,7 @@
|
||||
/>
|
||||
<!-- Deposit button - matches menu style -->
|
||||
<button
|
||||
v-if="!hideActions"
|
||||
@click="emit('deposit')"
|
||||
class="flex items-center gap-2 rounded-lg border border-gray-200 bg-white px-3 py-2 text-sm text-gray-700 transition-colors hover:bg-gray-50 dark:border-dark-600 dark:bg-dark-800 dark:text-gray-300 dark:hover:bg-dark-700"
|
||||
>
|
||||
@@ -62,6 +63,7 @@
|
||||
</button>
|
||||
<!-- Withdraw button - matches menu style -->
|
||||
<button
|
||||
v-if="!hideActions"
|
||||
@click="emit('withdraw')"
|
||||
class="flex items-center gap-2 rounded-lg border border-gray-200 bg-white px-3 py-2 text-sm text-gray-700 transition-colors hover:bg-gray-50 dark:border-dark-600 dark:bg-dark-800 dark:text-gray-300 dark:hover:bg-dark-700"
|
||||
>
|
||||
@@ -176,7 +178,7 @@ import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
|
||||
const props = defineProps<{ show: boolean; user: AdminUser | null }>()
|
||||
const props = defineProps<{ show: boolean; user: AdminUser | null; hideActions?: boolean }>()
|
||||
const emit = defineEmits(['close', 'deposit', 'withdraw'])
|
||||
const { t } = useI18n()
|
||||
|
||||
|
||||
@@ -152,6 +152,7 @@
|
||||
v-else
|
||||
v-for="(row, index) in sortedData"
|
||||
:key="resolveRowKey(row, index)"
|
||||
:data-row-id="resolveRowKey(row, index)"
|
||||
class="hover:bg-gray-50 dark:hover:bg-dark-800"
|
||||
>
|
||||
<td
|
||||
|
||||
@@ -1,37 +1,56 @@
|
||||
<template>
|
||||
<div class="flex min-w-0 flex-1 items-center justify-between gap-2">
|
||||
<div class="flex min-w-0 flex-1 items-start justify-between gap-3">
|
||||
<!-- Left: name + description -->
|
||||
<div
|
||||
class="flex min-w-0 flex-1 flex-col items-start gap-1"
|
||||
class="flex min-w-0 flex-1 flex-col items-start"
|
||||
:title="description || undefined"
|
||||
>
|
||||
<!-- Row 1: platform badge (name bold) -->
|
||||
<GroupBadge
|
||||
:name="name"
|
||||
:platform="platform"
|
||||
:subscription-type="subscriptionType"
|
||||
:rate-multiplier="rateMultiplier"
|
||||
:user-rate-multiplier="userRateMultiplier"
|
||||
:show-rate="false"
|
||||
class="groupOptionItemBadge"
|
||||
/>
|
||||
<!-- Row 2: description with top spacing -->
|
||||
<span
|
||||
v-if="description"
|
||||
class="w-full truncate text-left text-xs text-gray-500 dark:text-gray-400"
|
||||
class="mt-1.5 w-full text-left text-xs leading-relaxed text-gray-500 dark:text-gray-400 line-clamp-2"
|
||||
>
|
||||
{{ description }}
|
||||
</span>
|
||||
</div>
|
||||
<svg
|
||||
v-if="showCheckmark && selected"
|
||||
class="h-4 w-4 shrink-0 text-primary-600 dark:text-primary-400"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
|
||||
</svg>
|
||||
|
||||
<!-- Right: rate pill + checkmark (vertically centered to first row) -->
|
||||
<div class="flex shrink-0 items-center gap-2 pt-0.5">
|
||||
<!-- Rate pill (platform color) -->
|
||||
<span v-if="rateMultiplier !== undefined" :class="['inline-flex items-center whitespace-nowrap rounded-full px-3 py-1 text-xs font-semibold', ratePillClass]">
|
||||
<template v-if="hasCustomRate">
|
||||
<span class="mr-1 line-through opacity-50">{{ rateMultiplier }}x</span>
|
||||
<span class="font-bold">{{ userRateMultiplier }}x</span>
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ rateMultiplier }}x 倍率
|
||||
</template>
|
||||
</span>
|
||||
<!-- Checkmark -->
|
||||
<svg
|
||||
v-if="showCheckmark && selected"
|
||||
class="h-4 w-4 shrink-0 text-primary-600 dark:text-primary-400"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import GroupBadge from './GroupBadge.vue'
|
||||
import type { SubscriptionType, GroupPlatform } from '@/types'
|
||||
|
||||
@@ -46,10 +65,43 @@ interface Props {
|
||||
showCheckmark?: boolean
|
||||
}
|
||||
|
||||
withDefaults(defineProps<Props>(), {
|
||||
const props = withDefaults(defineProps<Props>(), {
|
||||
subscriptionType: 'standard',
|
||||
selected: false,
|
||||
showCheckmark: true,
|
||||
userRateMultiplier: null
|
||||
})
|
||||
|
||||
// Whether user has a custom rate different from default
|
||||
const hasCustomRate = computed(() => {
|
||||
return (
|
||||
props.userRateMultiplier !== null &&
|
||||
props.userRateMultiplier !== undefined &&
|
||||
props.rateMultiplier !== undefined &&
|
||||
props.userRateMultiplier !== props.rateMultiplier
|
||||
)
|
||||
})
|
||||
|
||||
// Rate pill color matches platform badge color
|
||||
const ratePillClass = computed(() => {
|
||||
switch (props.platform) {
|
||||
case 'anthropic':
|
||||
return 'bg-amber-50 text-amber-700 dark:bg-amber-900/20 dark:text-amber-400'
|
||||
case 'openai':
|
||||
return 'bg-green-50 text-green-700 dark:bg-green-900/20 dark:text-green-400'
|
||||
case 'gemini':
|
||||
return 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400'
|
||||
case 'sora':
|
||||
return 'bg-rose-50 text-rose-700 dark:bg-rose-900/20 dark:text-rose-400'
|
||||
default: // antigravity and others
|
||||
return 'bg-violet-50 text-violet-700 dark:bg-violet-900/20 dark:text-violet-400'
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
/* Bold the group name inside GroupBadge when used in dropdown option */
|
||||
.groupOptionItemBadge :deep(span.truncate) {
|
||||
font-weight: 600;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -224,7 +224,13 @@ const filteredOptions = computed(() => {
|
||||
let opts = props.options as any[]
|
||||
if (props.searchable && searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
opts = opts.filter((opt) => getOptionLabel(opt).toLowerCase().includes(query))
|
||||
opts = opts.filter((opt) => {
|
||||
// Match label
|
||||
if (getOptionLabel(opt).toLowerCase().includes(query)) return true
|
||||
// Also match description if present
|
||||
if (opt.description && String(opt.description).toLowerCase().includes(query)) return true
|
||||
return false
|
||||
})
|
||||
}
|
||||
return opts
|
||||
})
|
||||
@@ -434,7 +440,7 @@ onUnmounted(() => {
|
||||
|
||||
<style>
|
||||
.select-dropdown-portal {
|
||||
@apply w-max min-w-[160px] max-w-[320px];
|
||||
@apply w-max min-w-[200px];
|
||||
@apply bg-white dark:bg-dark-800;
|
||||
@apply rounded-xl;
|
||||
@apply border border-gray-200 dark:border-dark-700;
|
||||
|
||||
410
frontend/src/composables/useSwipeSelect.ts
Normal file
410
frontend/src/composables/useSwipeSelect.ts
Normal file
@@ -0,0 +1,410 @@
|
||||
import { ref, onMounted, onUnmounted, type Ref } from 'vue'
|
||||
|
||||
/**
|
||||
* WeChat-style swipe/drag to select rows in a DataTable,
|
||||
* with a semi-transparent marquee overlay showing the selection area.
|
||||
*
|
||||
* Features:
|
||||
* - Start dragging inside the current table-page layout's non-text area
|
||||
* - Mouse wheel scrolling continues selecting new rows
|
||||
* - Auto-scroll when dragging near viewport edges
|
||||
* - 5px drag threshold to avoid accidental selection on click
|
||||
*
|
||||
* Usage:
|
||||
* const containerRef = ref<HTMLElement | null>(null)
|
||||
* useSwipeSelect(containerRef, {
|
||||
* isSelected: (id) => selIds.value.includes(id),
|
||||
* select: (id) => { if (!selIds.value.includes(id)) selIds.value.push(id) },
|
||||
* deselect: (id) => { selIds.value = selIds.value.filter(x => x !== id) },
|
||||
* })
|
||||
*
|
||||
* Wrap <DataTable> with <div ref="containerRef">...</div>
|
||||
* DataTable rows must have data-row-id attribute.
|
||||
*/
|
||||
export interface SwipeSelectAdapter {
|
||||
isSelected: (id: number) => boolean
|
||||
select: (id: number) => void
|
||||
deselect: (id: number) => void
|
||||
}
|
||||
|
||||
export function useSwipeSelect(
|
||||
containerRef: Ref<HTMLElement | null>,
|
||||
adapter: SwipeSelectAdapter
|
||||
) {
|
||||
const isDragging = ref(false)
|
||||
|
||||
let dragMode: 'select' | 'deselect' = 'select'
|
||||
let startRowIndex = -1
|
||||
let lastEndIndex = -1
|
||||
let startY = 0
|
||||
let lastMouseY = 0
|
||||
let pendingStartY = 0
|
||||
let initialSelectedSnapshot = new Map<number, boolean>()
|
||||
let cachedRows: HTMLElement[] = []
|
||||
let marqueeEl: HTMLDivElement | null = null
|
||||
let cachedScrollParent: HTMLElement | null = null
|
||||
|
||||
const DRAG_THRESHOLD = 5
|
||||
const SCROLL_ZONE = 60
|
||||
const SCROLL_SPEED = 8
|
||||
|
||||
function getActivationRoot(): HTMLElement | null {
|
||||
const container = containerRef.value
|
||||
if (!container) return null
|
||||
return container.closest('.table-page-layout') as HTMLElement | null || container
|
||||
}
|
||||
|
||||
function getDataRows(): HTMLElement[] {
|
||||
const container = containerRef.value
|
||||
if (!container) return []
|
||||
return Array.from(container.querySelectorAll('tbody tr[data-row-id]'))
|
||||
}
|
||||
|
||||
function getRowId(el: HTMLElement): number | null {
|
||||
const raw = el.getAttribute('data-row-id')
|
||||
if (raw === null) return null
|
||||
const id = Number(raw)
|
||||
return Number.isFinite(id) ? id : null
|
||||
}
|
||||
|
||||
/** Find the row index closest to a viewport Y coordinate (binary search). */
|
||||
function findRowIndexAtY(clientY: number): number {
|
||||
const len = cachedRows.length
|
||||
if (len === 0) return -1
|
||||
|
||||
// Boundary checks
|
||||
const firstRect = cachedRows[0].getBoundingClientRect()
|
||||
if (clientY < firstRect.top) return 0
|
||||
const lastRect = cachedRows[len - 1].getBoundingClientRect()
|
||||
if (clientY > lastRect.bottom) return len - 1
|
||||
|
||||
// Binary search — rows are vertically ordered
|
||||
let lo = 0, hi = len - 1
|
||||
while (lo <= hi) {
|
||||
const mid = (lo + hi) >>> 1
|
||||
const rect = cachedRows[mid].getBoundingClientRect()
|
||||
if (clientY < rect.top) hi = mid - 1
|
||||
else if (clientY > rect.bottom) lo = mid + 1
|
||||
else return mid
|
||||
}
|
||||
// In a gap between rows — pick the closer one
|
||||
if (hi < 0) return 0
|
||||
if (lo >= len) return len - 1
|
||||
const rHi = cachedRows[hi].getBoundingClientRect()
|
||||
const rLo = cachedRows[lo].getBoundingClientRect()
|
||||
return (clientY - rHi.bottom < rLo.top - clientY) ? hi : lo
|
||||
}
|
||||
|
||||
// --- Prevent text selection via selectstart (no body style mutation) ---
|
||||
function onSelectStart(e: Event) { e.preventDefault() }
|
||||
|
||||
// --- Marquee overlay ---
|
||||
function createMarquee() {
|
||||
removeMarquee() // defensive: remove any stale marquee
|
||||
marqueeEl = document.createElement('div')
|
||||
const isDark = document.documentElement.classList.contains('dark')
|
||||
Object.assign(marqueeEl.style, {
|
||||
position: 'fixed',
|
||||
background: isDark ? 'rgba(96, 165, 250, 0.15)' : 'rgba(59, 130, 246, 0.12)',
|
||||
border: isDark ? '1.5px solid rgba(96, 165, 250, 0.5)' : '1.5px solid rgba(59, 130, 246, 0.4)',
|
||||
borderRadius: '4px',
|
||||
pointerEvents: 'none',
|
||||
zIndex: '9999',
|
||||
transition: 'none',
|
||||
})
|
||||
document.body.appendChild(marqueeEl)
|
||||
}
|
||||
|
||||
function updateMarquee(currentY: number) {
|
||||
if (!marqueeEl || !containerRef.value) return
|
||||
const containerRect = containerRef.value.getBoundingClientRect()
|
||||
const top = Math.min(startY, currentY)
|
||||
const bottom = Math.max(startY, currentY)
|
||||
marqueeEl.style.left = containerRect.left + 'px'
|
||||
marqueeEl.style.width = containerRect.width + 'px'
|
||||
marqueeEl.style.top = top + 'px'
|
||||
marqueeEl.style.height = (bottom - top) + 'px'
|
||||
}
|
||||
|
||||
function removeMarquee() {
|
||||
if (marqueeEl) { marqueeEl.remove(); marqueeEl = null }
|
||||
}
|
||||
|
||||
// --- Row selection logic ---
|
||||
function applyRange(endIndex: number) {
|
||||
if (startRowIndex < 0 || endIndex < 0) return
|
||||
const rangeMin = Math.min(startRowIndex, endIndex)
|
||||
const rangeMax = Math.max(startRowIndex, endIndex)
|
||||
const prevMin = lastEndIndex >= 0 ? Math.min(startRowIndex, lastEndIndex) : rangeMin
|
||||
const prevMax = lastEndIndex >= 0 ? Math.max(startRowIndex, lastEndIndex) : rangeMax
|
||||
const lo = Math.min(rangeMin, prevMin)
|
||||
const hi = Math.max(rangeMax, prevMax)
|
||||
|
||||
for (let i = lo; i <= hi && i < cachedRows.length; i++) {
|
||||
const id = getRowId(cachedRows[i])
|
||||
if (id === null) continue
|
||||
if (i >= rangeMin && i <= rangeMax) {
|
||||
if (dragMode === 'select') adapter.select(id)
|
||||
else adapter.deselect(id)
|
||||
} else {
|
||||
const wasSelected = initialSelectedSnapshot.get(id) ?? false
|
||||
if (wasSelected) adapter.select(id)
|
||||
else adapter.deselect(id)
|
||||
}
|
||||
}
|
||||
lastEndIndex = endIndex
|
||||
}
|
||||
|
||||
// --- Scrollable parent ---
|
||||
function getScrollParent(el: HTMLElement): HTMLElement {
|
||||
let parent = el.parentElement
|
||||
while (parent && parent !== document.documentElement) {
|
||||
const { overflow, overflowY } = getComputedStyle(parent)
|
||||
if (/(auto|scroll)/.test(overflow + overflowY)) return parent
|
||||
parent = parent.parentElement
|
||||
}
|
||||
return document.documentElement
|
||||
}
|
||||
|
||||
// --- Scrollbar click detection ---
|
||||
/** Check if click lands on a scrollbar of the target element or any ancestor. */
|
||||
function isOnScrollbar(e: MouseEvent): boolean {
|
||||
let el = e.target as HTMLElement | null
|
||||
while (el && el !== document.documentElement) {
|
||||
const hasVScroll = el.scrollHeight > el.clientHeight
|
||||
const hasHScroll = el.scrollWidth > el.clientWidth
|
||||
if (hasVScroll || hasHScroll) {
|
||||
const rect = el.getBoundingClientRect()
|
||||
// clientWidth/clientHeight exclude scrollbar; offsetWidth/offsetHeight include it
|
||||
if (hasVScroll && e.clientX > rect.left + el.clientWidth) return true
|
||||
if (hasHScroll && e.clientY > rect.top + el.clientHeight) return true
|
||||
}
|
||||
el = el.parentElement
|
||||
}
|
||||
// Document-level scrollbar
|
||||
const docEl = document.documentElement
|
||||
if (e.clientX >= docEl.clientWidth || e.clientY >= docEl.clientHeight) return true
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* If the mousedown starts on inner cell content rather than cell padding,
|
||||
* prefer the browser's native text selection so users can copy text normally.
|
||||
*/
|
||||
function shouldPreferNativeTextSelection(target: HTMLElement): boolean {
|
||||
const row = target.closest('tbody tr[data-row-id]')
|
||||
if (!row) return false
|
||||
|
||||
const cell = target.closest('td, th')
|
||||
if (!cell) return false
|
||||
|
||||
return target !== cell && !target.closest('[data-swipe-select-handle]')
|
||||
}
|
||||
|
||||
function hasDirectTextContent(target: HTMLElement): boolean {
|
||||
return Array.from(target.childNodes).some(
|
||||
(node) => node.nodeType === Node.TEXT_NODE && (node.textContent?.trim().length ?? 0) > 0
|
||||
)
|
||||
}
|
||||
|
||||
function shouldPreferNativeSelectionOutsideRows(target: HTMLElement): boolean {
|
||||
const activationRoot = getActivationRoot()
|
||||
if (!activationRoot) return false
|
||||
if (!activationRoot.contains(target)) return false
|
||||
if (target.closest('tbody tr[data-row-id]')) return false
|
||||
|
||||
return hasDirectTextContent(target)
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Phase 1: detect drag threshold (5px movement)
|
||||
// =============================================
|
||||
function onMouseDown(e: MouseEvent) {
|
||||
if (e.button !== 0) return
|
||||
if (!containerRef.value) return
|
||||
|
||||
const target = e.target as HTMLElement
|
||||
const activationRoot = getActivationRoot()
|
||||
if (!activationRoot || !activationRoot.contains(target)) return
|
||||
|
||||
// Skip clicks on any scrollbar (inner containers + document)
|
||||
if (isOnScrollbar(e)) return
|
||||
|
||||
if (target.closest('button, a, input, select, textarea, [role="button"], [role="menuitem"], [role="combobox"], [role="dialog"]')) return
|
||||
if (shouldPreferNativeTextSelection(target)) return
|
||||
if (shouldPreferNativeSelectionOutsideRows(target)) return
|
||||
|
||||
cachedRows = getDataRows()
|
||||
if (cachedRows.length === 0) return
|
||||
|
||||
pendingStartY = e.clientY
|
||||
// Prevent text selection as soon as the mouse is down,
|
||||
// before the drag threshold is reached (Phase 1).
|
||||
// Without this, the browser starts selecting text during
|
||||
// the 0–5px threshold movement window.
|
||||
document.addEventListener('selectstart', onSelectStart)
|
||||
document.addEventListener('mousemove', onThresholdMove)
|
||||
document.addEventListener('mouseup', onThresholdUp)
|
||||
}
|
||||
|
||||
function onThresholdMove(e: MouseEvent) {
|
||||
if (Math.abs(e.clientY - pendingStartY) < DRAG_THRESHOLD) return
|
||||
// Threshold exceeded — begin actual drag
|
||||
document.removeEventListener('mousemove', onThresholdMove)
|
||||
document.removeEventListener('mouseup', onThresholdUp)
|
||||
|
||||
beginDrag(pendingStartY)
|
||||
|
||||
// Process the move that crossed the threshold
|
||||
lastMouseY = e.clientY
|
||||
updateMarquee(e.clientY)
|
||||
const rowIdx = findRowIndexAtY(e.clientY)
|
||||
if (rowIdx >= 0) applyRange(rowIdx)
|
||||
autoScroll(e)
|
||||
|
||||
document.addEventListener('mousemove', onMouseMove)
|
||||
document.addEventListener('mouseup', onMouseUp)
|
||||
document.addEventListener('wheel', onWheel, { passive: true })
|
||||
}
|
||||
|
||||
function onThresholdUp() {
|
||||
document.removeEventListener('mousemove', onThresholdMove)
|
||||
document.removeEventListener('mouseup', onThresholdUp)
|
||||
// Phase 1 ended without crossing threshold — remove selectstart blocker
|
||||
document.removeEventListener('selectstart', onSelectStart)
|
||||
cachedRows = []
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Phase 2: actual drag session
|
||||
// ============================
|
||||
function beginDrag(clientY: number) {
|
||||
startRowIndex = findRowIndexAtY(clientY)
|
||||
const startRowId = startRowIndex >= 0 ? getRowId(cachedRows[startRowIndex]) : null
|
||||
dragMode = (startRowId !== null && adapter.isSelected(startRowId)) ? 'deselect' : 'select'
|
||||
|
||||
initialSelectedSnapshot = new Map()
|
||||
for (const row of cachedRows) {
|
||||
const id = getRowId(row)
|
||||
if (id !== null) initialSelectedSnapshot.set(id, adapter.isSelected(id))
|
||||
}
|
||||
|
||||
isDragging.value = true
|
||||
startY = clientY
|
||||
lastMouseY = clientY
|
||||
lastEndIndex = -1
|
||||
cachedScrollParent = cachedRows.length > 0
|
||||
? getScrollParent(cachedRows[0])
|
||||
: (containerRef.value ? getScrollParent(containerRef.value) : null)
|
||||
|
||||
createMarquee()
|
||||
updateMarquee(clientY)
|
||||
applyRange(startRowIndex)
|
||||
// selectstart is already blocked since Phase 1 (onMouseDown).
|
||||
// Clear any text selection that the browser may have started
|
||||
// before our selectstart handler took effect.
|
||||
window.getSelection()?.removeAllRanges()
|
||||
}
|
||||
|
||||
function onMouseMove(e: MouseEvent) {
|
||||
if (!isDragging.value) return
|
||||
lastMouseY = e.clientY
|
||||
updateMarquee(e.clientY)
|
||||
const rowIdx = findRowIndexAtY(e.clientY)
|
||||
if (rowIdx >= 0 && rowIdx !== lastEndIndex) applyRange(rowIdx)
|
||||
autoScroll(e)
|
||||
}
|
||||
|
||||
function onWheel() {
|
||||
if (!isDragging.value) return
|
||||
// After wheel scroll, rows shift in viewport — re-check selection
|
||||
requestAnimationFrame(() => {
|
||||
if (!isDragging.value) return // guard: drag may have ended before this frame
|
||||
const rowIdx = findRowIndexAtY(lastMouseY)
|
||||
if (rowIdx >= 0) applyRange(rowIdx)
|
||||
})
|
||||
}
|
||||
|
||||
function cleanupDrag() {
|
||||
isDragging.value = false
|
||||
startRowIndex = -1
|
||||
lastEndIndex = -1
|
||||
cachedRows = []
|
||||
initialSelectedSnapshot.clear()
|
||||
cachedScrollParent = null
|
||||
stopAutoScroll()
|
||||
removeMarquee()
|
||||
document.removeEventListener('selectstart', onSelectStart)
|
||||
document.removeEventListener('mousemove', onMouseMove)
|
||||
document.removeEventListener('mouseup', onMouseUp)
|
||||
document.removeEventListener('wheel', onWheel)
|
||||
}
|
||||
|
||||
function onMouseUp() {
|
||||
cleanupDrag()
|
||||
}
|
||||
|
||||
// Guard: clean up if mouse leaves window or window loses focus during drag
|
||||
function onWindowBlur() {
|
||||
if (isDragging.value) cleanupDrag()
|
||||
// Also clean up threshold phase (Phase 1)
|
||||
document.removeEventListener('mousemove', onThresholdMove)
|
||||
document.removeEventListener('mouseup', onThresholdUp)
|
||||
document.removeEventListener('selectstart', onSelectStart)
|
||||
}
|
||||
|
||||
// --- Auto-scroll logic ---
|
||||
let scrollRAF = 0
|
||||
|
||||
function autoScroll(e: MouseEvent) {
|
||||
cancelAnimationFrame(scrollRAF)
|
||||
const scrollEl = cachedScrollParent
|
||||
if (!scrollEl) return
|
||||
|
||||
let dy = 0
|
||||
if (scrollEl === document.documentElement) {
|
||||
if (e.clientY < SCROLL_ZONE) dy = -SCROLL_SPEED
|
||||
else if (e.clientY > window.innerHeight - SCROLL_ZONE) dy = SCROLL_SPEED
|
||||
} else {
|
||||
const rect = scrollEl.getBoundingClientRect()
|
||||
if (e.clientY < rect.top + SCROLL_ZONE) dy = -SCROLL_SPEED
|
||||
else if (e.clientY > rect.bottom - SCROLL_ZONE) dy = SCROLL_SPEED
|
||||
}
|
||||
|
||||
if (dy !== 0) {
|
||||
const step = () => {
|
||||
const prevScrollTop = scrollEl.scrollTop
|
||||
scrollEl.scrollTop += dy
|
||||
// Only re-check selection if scroll actually moved
|
||||
if (scrollEl.scrollTop !== prevScrollTop) {
|
||||
const rowIdx = findRowIndexAtY(lastMouseY)
|
||||
if (rowIdx >= 0 && rowIdx !== lastEndIndex) applyRange(rowIdx)
|
||||
}
|
||||
scrollRAF = requestAnimationFrame(step)
|
||||
}
|
||||
scrollRAF = requestAnimationFrame(step)
|
||||
}
|
||||
}
|
||||
|
||||
function stopAutoScroll() {
|
||||
cancelAnimationFrame(scrollRAF)
|
||||
}
|
||||
|
||||
// --- Lifecycle ---
|
||||
onMounted(() => {
|
||||
document.addEventListener('mousedown', onMouseDown)
|
||||
window.addEventListener('blur', onWindowBlur)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
document.removeEventListener('mousedown', onMouseDown)
|
||||
window.removeEventListener('blur', onWindowBlur)
|
||||
// Clean up any in-progress drag state
|
||||
document.removeEventListener('mousemove', onThresholdMove)
|
||||
document.removeEventListener('mouseup', onThresholdUp)
|
||||
document.removeEventListener('selectstart', onSelectStart)
|
||||
cleanupDrag()
|
||||
})
|
||||
|
||||
return { isDragging }
|
||||
}
|
||||
98
frontend/src/composables/useTableSelection.ts
Normal file
98
frontend/src/composables/useTableSelection.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
import { computed, ref, type Ref } from 'vue'
|
||||
|
||||
interface UseTableSelectionOptions<T> {
|
||||
rows: Ref<T[]>
|
||||
getId: (row: T) => number
|
||||
}
|
||||
|
||||
export function useTableSelection<T>({ rows, getId }: UseTableSelectionOptions<T>) {
|
||||
const selectedSet = ref<Set<number>>(new Set())
|
||||
|
||||
const selectedIds = computed(() => Array.from(selectedSet.value))
|
||||
const selectedCount = computed(() => selectedSet.value.size)
|
||||
|
||||
const isSelected = (id: number) => selectedSet.value.has(id)
|
||||
|
||||
const replaceSelectedSet = (next: Set<number>) => {
|
||||
selectedSet.value = next
|
||||
}
|
||||
|
||||
const setSelectedIds = (ids: number[]) => {
|
||||
selectedSet.value = new Set(ids)
|
||||
}
|
||||
|
||||
const select = (id: number) => {
|
||||
if (selectedSet.value.has(id)) return
|
||||
const next = new Set(selectedSet.value)
|
||||
next.add(id)
|
||||
replaceSelectedSet(next)
|
||||
}
|
||||
|
||||
const deselect = (id: number) => {
|
||||
if (!selectedSet.value.has(id)) return
|
||||
const next = new Set(selectedSet.value)
|
||||
next.delete(id)
|
||||
replaceSelectedSet(next)
|
||||
}
|
||||
|
||||
const toggle = (id: number) => {
|
||||
if (selectedSet.value.has(id)) {
|
||||
deselect(id)
|
||||
return
|
||||
}
|
||||
select(id)
|
||||
}
|
||||
|
||||
const clear = () => {
|
||||
if (selectedSet.value.size === 0) return
|
||||
replaceSelectedSet(new Set())
|
||||
}
|
||||
|
||||
const removeMany = (ids: number[]) => {
|
||||
if (ids.length === 0 || selectedSet.value.size === 0) return
|
||||
const next = new Set(selectedSet.value)
|
||||
let changed = false
|
||||
ids.forEach((id) => {
|
||||
if (next.delete(id)) changed = true
|
||||
})
|
||||
if (changed) replaceSelectedSet(next)
|
||||
}
|
||||
|
||||
const allVisibleSelected = computed(() => {
|
||||
if (rows.value.length === 0) return false
|
||||
return rows.value.every((row) => selectedSet.value.has(getId(row)))
|
||||
})
|
||||
|
||||
const toggleVisible = (checked: boolean) => {
|
||||
const next = new Set(selectedSet.value)
|
||||
rows.value.forEach((row) => {
|
||||
const id = getId(row)
|
||||
if (checked) {
|
||||
next.add(id)
|
||||
} else {
|
||||
next.delete(id)
|
||||
}
|
||||
})
|
||||
replaceSelectedSet(next)
|
||||
}
|
||||
|
||||
const selectVisible = () => {
|
||||
toggleVisible(true)
|
||||
}
|
||||
|
||||
return {
|
||||
selectedSet,
|
||||
selectedIds,
|
||||
selectedCount,
|
||||
allVisibleSelected,
|
||||
isSelected,
|
||||
setSelectedIds,
|
||||
select,
|
||||
deselect,
|
||||
toggle,
|
||||
clear,
|
||||
removeMany,
|
||||
toggleVisible,
|
||||
selectVisible
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user