diff --git a/README.md b/README.md index c83bd27e..4a7bde8e 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot - **Concurrency Control** - Per-user and per-account concurrency limits - **Rate Limiting** - Configurable request and token rate limits - **Admin Dashboard** - Web interface for monitoring and management +- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard + +## Ecosystem + +Community projects that extend or integrate with Sub2API: + +| Project | Description | Features | +|---------|-------------|----------| +| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe | +| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native | ## Tech Stack diff --git a/README_CN.md b/README_CN.md index a5ad8a94..eee89b07 100644 --- a/README_CN.md +++ b/README_CN.md @@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( - **并发控制** - 用户级和账号级并发限制 - **速率限制** - 可配置的请求和 Token 速率限制 - **管理后台** - Web 界面进行监控和管理 +- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能 + +## 生态项目 + +围绕 Sub2API 的社区扩展与集成项目: + +| 项目 | 说明 | 功能 | +|------|------|------| +| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 | +| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 | ## 技术栈 diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 36d043b5..c51046a2 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -27,12 +27,11 @@ const ( // Account type constants const ( - AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = "apikey" // API Key类型账号 - AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) - AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock) - AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock) + AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = "apikey" // API Key类型账号 + AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) ) // Redeem type constants diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index c7ca0ca2..3ef213e1 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -97,7 +97,7 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"` Credentials map[string]any `json:"credentials" binding:"required"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -116,7 +116,7 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes *string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -1718,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // Handle OpenAI accounts if account.IsOpenAI() { - // For OAuth accounts: return default OpenAI models - if account.IsOAuth() { + // OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。 + if account.IsOpenAIPassthroughEnabled() { response.Success(c, openai.DefaultModels) return } - // For API Key accounts: check model_mapping mapping := account.GetModelMapping() if len(mapping) == 0 { response.Success(c, openai.DefaultModels) diff --git a/backend/internal/handler/admin/account_handler_available_models_test.go b/backend/internal/handler/admin/account_handler_available_models_test.go new file mode 100644 index 00000000..c5f1e2d8 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_available_models_test.go @@ -0,0 +1,105 @@ +package admin + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type availableModelsAdminService struct { + *stubAdminService + account service.Account +} + +func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) { + if s.account.ID == id { + acc := s.account + return &acc, nil + } + return s.stubAdminService.GetAccount(context.Background(), id) +} + +func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels) + return router +} + +func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 42, + Name: "openai-oauth", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.1", + }, + }, + }, + } + router := setupAvailableModelsRouter(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Len(t, resp.Data, 1) + require.Equal(t, "gpt-5", resp.Data[0].ID) +} + +func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 43, + Name: "openai-oauth-passthrough", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.1", + }, + }, + Extra: map[string]any{ + "openai_passthrough": true, + }, + }, + } + router := setupAvailableModelsRouter(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotEmpty(t, resp.Data) + require.NotEqual(t, "gpt-5", resp.Data[0].ID) +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 8330868d..ff76edda 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -125,6 +125,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, MinClaudeCodeVersion: settings.MinClaudeCodeVersion, AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, + BackendModeEnabled: settings.BackendModeEnabled, }) } @@ -199,6 +200,9 @@ type UpdateSettingsRequest struct { // 分组隔离 AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` + + // Backend Mode + BackendModeEnabled bool `json:"backend_mode_enabled"` } // UpdateSettings 更新系统设置 @@ -473,6 +477,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { IdentityPatchPrompt: req.IdentityPatchPrompt, MinClaudeCodeVersion: req.MinClaudeCodeVersion, AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, + BackendModeEnabled: req.BackendModeEnabled, OpsMonitoringEnabled: func() bool { if req.OpsMonitoringEnabled != nil { return *req.OpsMonitoringEnabled @@ -571,6 +576,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, + BackendModeEnabled: updatedSettings.BackendModeEnabled, }) } @@ -725,6 +731,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling { changed = append(changed, "allow_ungrouped_key_scheduling") } + if before.BackendModeEnabled != after.BackendModeEnabled { + changed = append(changed, "backend_mode_enabled") + } if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { changed = append(changed, "purchase_subscription_enabled") } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 1ffa9d71..3b257189 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) { return } + // Backend mode: only admin can login + if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { + response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") + return + } + h.respondWithTokenPair(c, user) } @@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { return } - // Delete the login session - _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) - - // Get the user + // Get the user (before session deletion so we can check backend mode) user, err := h.userService.GetByID(c.Request.Context(), session.UserID) if err != nil { response.ErrorFrom(c, err) return } + // Backend mode: only admin can login (check BEFORE deleting session) + if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { + response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") + return + } + + // Delete the login session (only after all checks pass) + _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) + h.respondWithTokenPair(c, user) } @@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) { return } - tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) + result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) if err != nil { response.ErrorFrom(c, err) return } + // Backend mode: block non-admin token refresh + if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" { + response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") + return + } + response.Success(c, RefreshTokenResponse{ - AccessToken: tokenPair.AccessToken, - RefreshToken: tokenPair.RefreshToken, - ExpiresIn: tokenPair.ExpiresIn, + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresIn: result.ExpiresIn, TokenType: "Bearer", }) } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 205ccd65..3706f725 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account { } } - // 提取 API Key 账号配额限制(仅 apikey 类型有效) - if a.Type == service.AccountTypeAPIKey { + // 提取账号配额限制(apikey / bedrock 类型有效) + if a.IsAPIKeyOrBedrock() { if limit := a.GetQuotaLimit(); limit > 0 { out.QuotaLimit = &limit used := a.GetQuotaUsed() @@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account { used := a.GetQuotaWeeklyUsed() out.QuotaWeeklyUsed = &used } + // 固定时间重置配置 + if mode := a.GetQuotaDailyResetMode(); mode == "fixed" { + out.QuotaDailyResetMode = &mode + hour := a.GetQuotaDailyResetHour() + out.QuotaDailyResetHour = &hour + } + if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" { + out.QuotaWeeklyResetMode = &mode + day := a.GetQuotaWeeklyResetDay() + out.QuotaWeeklyResetDay = &day + hour := a.GetQuotaWeeklyResetHour() + out.QuotaWeeklyResetHour = &hour + } + if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" { + tz := a.GetQuotaResetTimezone() + out.QuotaResetTimezone = &tz + } + if a.Extra != nil { + if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" { + out.QuotaDailyResetAt = &v + } + if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" { + out.QuotaWeeklyResetAt = &v + } + } } return out diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 8a1bba5d..3df54fe9 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -81,6 +81,9 @@ type SystemSettings struct { // 分组隔离 AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` + + // Backend Mode + BackendModeEnabled bool `json:"backend_mode_enabled"` } type DefaultSubscriptionSetting struct { @@ -111,6 +114,7 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"` + BackendModeEnabled bool `json:"backend_mode_enabled"` Version string `json:"version"` } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index d9ccda2d..3708eed5 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -203,6 +203,16 @@ type Account struct { QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"` QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"` + // 配额固定时间重置配置 + QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"` + QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"` + QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"` + QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"` + QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"` + QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"` + QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"` + QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 6900e7cd..724376e3 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - defaultMappedModel := "" - if apiKey.Group != nil { - defaultMappedModel = apiKey.Group.DefaultMappedModel - } - if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" { - defaultMappedModel = fallbackModel - } + defaultMappedModel := c.GetString("openai_chat_completions_fallback_model") result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index d23c7efe..87b0d0d6 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -655,14 +655,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - defaultMappedModel := "" - if apiKey.Group != nil { - defaultMappedModel = apiKey.Group.DefaultMappedModel - } - // 如果使用了降级模型调度,强制使用降级模型 - if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" { - defaultMappedModel = fallbackModel - } + // 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时, + // 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。 + defaultMappedModel := c.GetString("openai_messages_fallback_model") result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 1188d55e..92061895 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, SoraClientEnabled: settings.SoraClientEnabled, + BackendModeEnabled: settings.BackendModeEnabled, Version: h.version, }) } diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index d46bbc45..1e63315b 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -19,6 +19,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) +// ForbiddenError 表示上游返回 403 Forbidden +type ForbiddenError struct { + StatusCode int + Body string +} + +func (e *ForbiddenError) Error() string { + return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body) +} + // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 @@ -514,7 +524,20 @@ type ModelQuotaInfo struct { // ModelInfo 模型信息 type ModelInfo struct { - QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"` + QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"` + DisplayName string `json:"displayName,omitempty"` + SupportsImages *bool `json:"supportsImages,omitempty"` + SupportsThinking *bool `json:"supportsThinking,omitempty"` + ThinkingBudget *int `json:"thinkingBudget,omitempty"` + Recommended *bool `json:"recommended,omitempty"` + MaxTokens *int `json:"maxTokens,omitempty"` + MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` + SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"` +} + +// DeprecatedModelInfo 废弃模型转发信息 +type DeprecatedModelInfo struct { + NewModelID string `json:"newModelId"` } // FetchAvailableModelsRequest fetchAvailableModels 请求 @@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct { // FetchAvailableModelsResponse fetchAvailableModels 响应 type FetchAvailableModelsResponse struct { - Models map[string]ModelInfo `json:"models"` + Models map[string]ModelInfo `json:"models"` + DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"` } // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON @@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI continue } + if resp.StatusCode == http.StatusForbidden { + return nil, nil, &ForbiddenError{ + StatusCode: resp.StatusCode, + Body: string(respBodyBytes), + } + } + if resp.StatusCode != http.StatusOK { return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) } diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index 1c1d39bb..2db65572 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) { assert.Equal(t, "assistant", items[1].Role) assert.Equal(t, "function_call", items[2].Type) assert.Equal(t, "fc_call_1", items[2].CallID) + assert.Empty(t, items[2].ID) assert.Equal(t, "function_call_output", items[3].Type) assert.Equal(t, "fc_call_1", items[3].CallID) assert.Equal(t, "Sunny, 72°F", items[3].Output) diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go index 592bec39..0a747869 100644 --- a/backend/internal/pkg/apicompat/anthropic_to_responses.go +++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go @@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e CallID: fcID, Name: b.Name, Arguments: args, - ID: fcID, }) } diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index 71b7a6f5..8b819033 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) { // Check function_call item assert.Equal(t, "function_call", items[1].Type) assert.Equal(t, "call_1", items[1].CallID) + assert.Empty(t, items[1].ID) assert.Equal(t, "ping", items[1].Name) // Check function_call_output item @@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) assert.Equal(t, "user", items[0].Role) assert.Equal(t, "assistant", items[1].Role) assert.Equal(t, "function_call", items[2].Type) + assert.Empty(t, items[2].ID) +} + +func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "assistant", items[1].Role) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "output_text", parts[0].Type) + assert.Equal(t, "AB", parts[0].Text) +} + +func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "output_text", parts[0].Type) + assert.Contains(t, parts[0].Text, "internal plan") + assert.Contains(t, parts[0].Text, "final answer") } // --------------------------------------------------------------------------- @@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) { var content string require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) - // Reasoning summary is prepended to text - assert.Equal(t, "I thought about it.The answer is 42.", content) + assert.Equal(t, "The answer is 42.", content) + assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent) } func TestResponsesToChatCompletions_Incomplete(t *testing.T) { @@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) { Delta: "Thinking...", }, state) require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent) + assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent) + + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.done", + }, state) + require.Len(t, chunks, 0) +} + +func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.delta", + Delta: "plan", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent) + assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent) + + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "answer", + }, state) + require.Len(t, chunks, 1) require.NotNil(t, chunks[0].Choices[0].Delta.Content) - assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content) } func TestFinalizeResponsesChatStream(t *testing.T) { diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go index 37285b09..c4a9e773 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -3,6 +3,7 @@ package apicompat import ( "encoding/json" "fmt" + "strings" ) // ChatCompletionsToResponses converts a Chat Completions request into a @@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) { // Emit assistant message with output_text if content is non-empty. if len(m.Content) > 0 { - var s string - if err := json.Unmarshal(m.Content, &s); err == nil && s != "" { + s, err := parseAssistantContent(m.Content) + if err != nil { + return nil, err + } + if s != "" { parts := []ResponsesContentPart{{Type: "output_text", Text: s}} partsJSON, err := json.Marshal(parts) if err != nil { @@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) { CallID: tc.ID, Name: tc.Function.Name, Arguments: args, - ID: tc.ID, }) } return items, nil } +// parseAssistantContent returns assistant content as plain text. +// +// Supported formats: +// - JSON string +// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}]) +// +// For structured thinking/reasoning parts, it preserves semantics by wrapping +// the text in explicit tags so downstream can still distinguish it from normal text. +func parseAssistantContent(raw json.RawMessage) (string, error) { + if len(raw) == 0 { + return "", nil + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s, nil + } + + var parts []map[string]any + if err := json.Unmarshal(raw, &parts); err != nil { + // Keep compatibility with prior behavior: unsupported assistant content + // formats are ignored instead of failing the whole request conversion. + return "", nil + } + + var b strings.Builder + write := func(v string) error { + _, err := b.WriteString(v) + return err + } + for _, p := range parts { + typ, _ := p["type"].(string) + text, _ := p["text"].(string) + thinking, _ := p["thinking"].(string) + + switch typ { + case "thinking", "reasoning": + if thinking != "" { + if err := write(""); err != nil { + return "", err + } + if err := write(thinking); err != nil { + return "", err + } + if err := write(""); err != nil { + return "", err + } + } else if text != "" { + if err := write(""); err != nil { + return "", err + } + if err := write(text); err != nil { + return "", err + } + if err := write(""); err != nil { + return "", err + } + } + default: + if text != "" { + if err := write(text); err != nil { + return "", err + } + } + } + } + + return b.String(), nil +} + // chatToolToResponses converts a tool result message (role=tool) into a // function_call_output item. func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) { diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go index 8f83bce4..688a68eb 100644 --- a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp } var contentText string + var reasoningText string var toolCalls []ChatToolCall for _, item := range resp.Output { @@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp case "reasoning": for _, s := range item.Summary { if s.Type == "summary_text" && s.Text != "" { - contentText += s.Text + reasoningText += s.Text } } case "web_search_call": @@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp raw, _ := json.Marshal(contentText) msg.Content = raw } + if reasoningText != "" { + msg.ReasoningContent = reasoningText + } finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls) @@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent return resToChatHandleFuncArgsDelta(evt, state) case "response.reasoning_summary_text.delta": return resToChatHandleReasoningDelta(evt, state) + case "response.reasoning_summary_text.done": + return nil case "response.completed", "response.incomplete", "response.failed": return resToChatHandleCompleted(evt, state) default: @@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv if evt.Delta == "" { return nil } - content := evt.Delta - return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} + reasoning := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})} } func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index eb77d89f..b724a5ed 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -361,11 +361,12 @@ type ChatStreamOptions struct { // ChatMessage is a single message in the Chat Completions conversation. type ChatMessage struct { - Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function" - Content json.RawMessage `json:"content,omitempty"` - Name string `json:"name,omitempty"` - ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function" + Content json.RawMessage `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + Name string `json:"name,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` // Legacy function calling FunctionCall *ChatFunctionCall `json:"function_call,omitempty"` @@ -466,9 +467,10 @@ type ChatChunkChoice struct { // ChatDelta carries incremental content in a streaming chunk. type ChatDelta struct { - Role string `json:"role,omitempty"` - Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters - ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` + Role string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` } // --------------------------------------------------------------------------- diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index a9cb2cba..20ff7373 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) } - if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { - r.syncSchedulerAccountSnapshot(ctx, account.ID) - } + // 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照, + // 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。 + r.syncSchedulerAccountSnapshot(ctx, account.ID) return nil } @@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va // nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string. const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')` +// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired. +// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes. +const dailyExpiredExpr = `( + CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed' + THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz) + ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + END +)` + +// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired. +const weeklyExpiredExpr = `( + CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed' + THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz) + ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + END +)` + +// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs. +// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour. +// This correctly handles long-inactive accounts by jumping directly to the next valid reset point. +const nextDailyResetAtExpr = `( + CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed' + THEN to_char(( + -- Compute today's reset point in the configured timezone, then pick next future one + CASE WHEN NOW() >= ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + -- NOW() is at or past today's reset point → next reset is tomorrow + THEN ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval + + '1 day'::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + -- NOW() is before today's reset point → next reset is today + ELSE ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + END + ) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"') + ELSE NULL END +)` + +// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs. +// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour. +// This correctly handles long-inactive accounts by jumping directly to the next valid reset point. +const nextWeeklyResetAtExpr = `( + CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed' + THEN to_char(( + -- Compute this week's reset point in the configured timezone + -- Step 1: get today's date at reset hour in configured tz + -- Step 2: compute days forward to target weekday + -- Step 3: if same day but past reset hour, advance 7 days + CASE + WHEN ( + -- days_forward = (target_day - current_day + 7) % 7 + (COALESCE((extra->>'quota_weekly_reset_day')::int, 1) + - EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int + + 7) % 7 + ) = 0 AND NOW() >= ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + -- Same weekday and past reset hour → next week + THEN ( + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval + + '7 days'::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + ELSE ( + -- Advance to target weekday this week (or next if days_forward > 0) + date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')) + + (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval + + (( + (COALESCE((extra->>'quota_weekly_reset_day')::int, 1) + - EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int + + 7) % 7 + ) || ' days')::interval + ) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC') + END + ) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"') + ELSE NULL END +)` + // IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度) // 日/周额度在周期过期时自动重置为 0 再递增。 +// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。 func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { rows, err := r.sql.QueryContext(ctx, `UPDATE accounts SET extra = ( @@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN jsonb_build_object( 'quota_daily_used', - CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) - + '24 hours'::interval <= NOW() + CASE WHEN `+dailyExpiredExpr+` THEN $1 ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, 'quota_daily_start', - CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) - + '24 hours'::interval <= NOW() + CASE WHEN `+dailyExpiredExpr+` THEN `+nowUTC+` ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END ) + -- 固定模式重置时更新下次重置时间 + || CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL + THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`) + ELSE '{}'::jsonb END ELSE '{}'::jsonb END -- 周额度:仅在 quota_weekly_limit > 0 时处理 || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN jsonb_build_object( 'quota_weekly_used', - CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) - + '168 hours'::interval <= NOW() + CASE WHEN `+weeklyExpiredExpr+` THEN $1 ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, 'quota_weekly_start', - CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) - + '168 hours'::interval <= NOW() + CASE WHEN `+weeklyExpiredExpr+` THEN `+nowUTC+` ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END ) + -- 固定模式重置时更新下次重置时间 + || CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL + THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`) + ELSE '{}'::jsonb END ELSE '{}'::jsonb END ), updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL @@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am } // ResetQuotaUsed 重置账号所有维度的配额用量为 0 +// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间 func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error { _, err := r.sql.ExecContext(ctx, `UPDATE accounts SET extra = ( COALESCE(extra, '{}'::jsonb) || '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb - ) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW() + ) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL`, id) if err != nil { diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 29b699e6..e697802e 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() { s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status) } +func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "sync-credentials-update", + Status: service.StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.1", + }, + }, + }) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + + account.Credentials = map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.2", + }, + } + err := s.repo.Update(s.ctx, account) + s.Require().NoError(err, "Update") + + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any) + s.Require().True(ok) + s.Require().Equal("gpt-5.2", mapping["gpt-5"]) +} + func (s *AccountRepoSuite) TestDelete() { account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"}) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index d46e0624..fa298b2b 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -537,6 +537,7 @@ func TestAPIContracts(t *testing.T) { "purchase_subscription_url": "", "min_claude_code_version": "", "allow_ungrouped_key_scheduling": false, + "backend_mode_enabled": false, "custom_menu_items": [] } }`, diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go new file mode 100644 index 00000000..46482af3 --- /dev/null +++ b/backend/internal/server/middleware/backend_mode_guard.go @@ -0,0 +1,51 @@ +package middleware + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled. +// Must be placed AFTER JWT auth middleware so that the user role is available in context. +func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc { + return func(c *gin.Context) { + if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) { + c.Next() + return + } + role, _ := GetUserRoleFromContext(c) + if role == "admin" { + c.Next() + return + } + response.Forbidden(c, "Backend mode is active. User self-service is disabled.") + c.Abort() + } +} + +// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled. +// Allows: login, login/2fa, logout, refresh (admin needs these). +// Blocks: register, forgot-password, reset-password, OAuth, etc. +func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc { + return func(c *gin.Context) { + if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) { + c.Next() + return + } + path := c.Request.URL.Path + // Allow login, 2FA, logout, refresh, public settings + allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} + for _, suffix := range allowedSuffixes { + if strings.HasSuffix(path, suffix) { + c.Next() + return + } + } + response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.") + c.Abort() + } +} diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go new file mode 100644 index 00000000..8878ebc9 --- /dev/null +++ b/backend/internal/server/middleware/backend_mode_guard_test.go @@ -0,0 +1,239 @@ +//go:build unit + +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type bmSettingRepo struct { + values map[string]string +} + +func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) { + v, ok := r.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return v, nil +} + +func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error { + panic("unexpected Set call") +} + +func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error { + if r.values == nil { + r.values = make(map[string]string, len(settings)) + } + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (r *bmSettingRepo) Delete(_ context.Context, _ string) error { + panic("unexpected Delete call") +} + +func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService { + t.Helper() + + repo := &bmSettingRepo{ + values: map[string]string{ + service.SettingKeyBackendModeEnabled: enabled, + }, + } + svc := service.NewSettingService(repo, &config.Config{}) + require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{ + BackendModeEnabled: enabled == "true", + })) + + return svc +} + +func stringPtr(v string) *string { + return &v +} + +func TestBackendModeUserGuard(t *testing.T) { + tests := []struct { + name string + nilService bool + enabled string + role *string + wantStatus int + }{ + { + name: "disabled_allows_all", + enabled: "false", + role: stringPtr("user"), + wantStatus: http.StatusOK, + }, + { + name: "nil_service_allows_all", + nilService: true, + role: stringPtr("user"), + wantStatus: http.StatusOK, + }, + { + name: "enabled_admin_allowed", + enabled: "true", + role: stringPtr("admin"), + wantStatus: http.StatusOK, + }, + { + name: "enabled_user_blocked", + enabled: "true", + role: stringPtr("user"), + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_no_role_blocked", + enabled: "true", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_empty_role_blocked", + enabled: "true", + role: stringPtr(""), + wantStatus: http.StatusForbidden, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + if tc.role != nil { + role := *tc.role + r.Use(func(c *gin.Context) { + c.Set(string(ContextKeyUserRole), role) + c.Next() + }) + } + + var svc *service.SettingService + if !tc.nilService { + svc = newBackendModeSettingService(t, tc.enabled) + } + + r.Use(BackendModeUserGuard(svc)) + r.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + r.ServeHTTP(w, req) + + require.Equal(t, tc.wantStatus, w.Code) + }) + } +} + +func TestBackendModeAuthGuard(t *testing.T) { + tests := []struct { + name string + nilService bool + enabled string + path string + wantStatus int + }{ + { + name: "disabled_allows_all", + enabled: "false", + path: "/api/v1/auth/register", + wantStatus: http.StatusOK, + }, + { + name: "nil_service_allows_all", + nilService: true, + path: "/api/v1/auth/register", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_login", + enabled: "true", + path: "/api/v1/auth/login", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_login_2fa", + enabled: "true", + path: "/api/v1/auth/login/2fa", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_logout", + enabled: "true", + path: "/api/v1/auth/logout", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_refresh", + enabled: "true", + path: "/api/v1/auth/refresh", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_register", + enabled: "true", + path: "/api/v1/auth/register", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_blocks_forgot_password", + enabled: "true", + path: "/api/v1/auth/forgot-password", + wantStatus: http.StatusForbidden, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + + var svc *service.SettingService + if !tc.nilService { + svc = newBackendModeSettingService(t, tc.enabled) + } + + r.Use(BackendModeAuthGuard(svc)) + r.Any("/*path", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, tc.path, nil) + r.ServeHTTP(w, req) + + require.Equal(t, tc.wantStatus, w.Code) + }) + } +} diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 571986b4..99701531 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -107,9 +107,9 @@ func registerRoutes( v1 := r.Group("/api/v1") // 注册各模块路由 - routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) - routes.RegisterUserRoutes(v1, h, jwtAuth) - routes.RegisterSoraClientRoutes(v1, h, jwtAuth) + routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService) + routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) + routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService) routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 0efc9560..a6c0ecf5 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -6,6 +6,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/middleware" servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" @@ -17,12 +18,14 @@ func RegisterAuthRoutes( h *handler.Handlers, jwtAuth servermiddleware.JWTAuthMiddleware, redisClient *redis.Client, + settingService *service.SettingService, ) { // 创建速率限制器 rateLimiter := middleware.NewRateLimiter(redisClient) // 公开接口 auth := v1.Group("/auth") + auth.Use(servermiddleware.BackendModeAuthGuard(settingService)) { // 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close) auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ @@ -78,6 +81,7 @@ func RegisterAuthRoutes( // 需要认证的当前用户信息 authenticated := v1.Group("") authenticated.Use(gin.HandlerFunc(jwtAuth)) + authenticated.Use(servermiddleware.BackendModeUserGuard(settingService)) { authenticated.GET("/auth/me", h.Auth.GetCurrentUser) // 撤销所有会话(需要认证) diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go index 5ce8497c..4f411cec 100644 --- a/backend/internal/server/routes/auth_rate_limit_test.go +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine { c.Next() }), redisClient, + nil, ) return router diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go index 40ae0436..13fceb81 100644 --- a/backend/internal/server/routes/sora_client.go +++ b/backend/internal/server/routes/sora_client.go @@ -3,6 +3,7 @@ package routes import ( "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -12,6 +13,7 @@ func RegisterSoraClientRoutes( v1 *gin.RouterGroup, h *handler.Handlers, jwtAuth middleware.JWTAuthMiddleware, + settingService *service.SettingService, ) { if h.SoraClient == nil { return @@ -19,6 +21,7 @@ func RegisterSoraClientRoutes( authenticated := v1.Group("/sora") authenticated.Use(gin.HandlerFunc(jwtAuth)) + authenticated.Use(middleware.BackendModeUserGuard(settingService)) { authenticated.POST("/generate", h.SoraClient.Generate) authenticated.GET("/generations", h.SoraClient.ListGenerations) diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index d0ed2489..c3b82742 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -3,6 +3,7 @@ package routes import ( "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -12,9 +13,11 @@ func RegisterUserRoutes( v1 *gin.RouterGroup, h *handler.Handlers, jwtAuth middleware.JWTAuthMiddleware, + settingService *service.SettingService, ) { authenticated := v1.Group("") authenticated.Use(gin.HandlerFunc(jwtAuth)) + authenticated.Use(middleware.BackendModeUserGuard(settingService)) { // 用户接口 user := authenticated.Group("/user") diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 9a871c10..578d1da3 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,6 +3,7 @@ package service import ( "encoding/json" + "errors" "hash/fnv" "reflect" "sort" @@ -522,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool { // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // 如果未配置 mapping,返回原始模型名 func (a *Account) GetMappedModel(requestedModel string) string { + mappedModel, _ := a.ResolveMappedModel(requestedModel) + return mappedModel +} + +// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。 +// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。 +func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) { mapping := a.GetModelMapping() if len(mapping) == 0 { - return requestedModel + return requestedModel, false } // 精确匹配优先 if mappedModel, exists := mapping[requestedModel]; exists { - return mappedModel + return mappedModel, true } // 通配符匹配(最长优先) - return matchWildcardMapping(mapping, requestedModel) + return matchWildcardMappingResult(mapping, requestedModel) } func (a *Account) GetBaseURL() string { @@ -605,9 +613,7 @@ func matchWildcard(pattern, str string) bool { return matchAntigravityWildcard(pattern, str) } -// matchWildcardMapping 通配符映射匹配(最长优先) -// 如果没有匹配,返回原始字符串 -func matchWildcardMapping(mapping map[string]string, requestedModel string) string { +func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) { // 收集所有匹配的 pattern,按长度降序排序(最长优先) type patternMatch struct { pattern string @@ -622,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri } if len(matches) == 0 { - return requestedModel // 无匹配,返回原始模型名 + return requestedModel, false // 无匹配,返回原始模型名 } // 按 pattern 长度降序排序 @@ -633,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri return matches[i].pattern < matches[j].pattern }) - return matches[0].target + return matches[0].target, true } func (a *Account) IsCustomErrorCodesEnabled() bool { @@ -651,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool { // IsPoolMode 检查 API Key 账号是否启用池模式。 // 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。 func (a *Account) IsPoolMode() bool { - if a.Type != AccountTypeAPIKey || a.Credentials == nil { + if !a.IsAPIKeyOrBedrock() || a.Credentials == nil { return false } if v, ok := a.Credentials["pool_mode"]; ok { @@ -766,11 +772,16 @@ func (a *Account) IsInterceptWarmupEnabled() bool { } func (a *Account) IsBedrock() bool { - return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey) + return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock } func (a *Account) IsBedrockAPIKey() bool { - return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey + return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey" +} + +// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性 +func (a *Account) IsAPIKeyOrBedrock() bool { + return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock } func (a *Account) IsOpenAI() bool { @@ -1269,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time { return time.Time{} } +// getExtraString 从 Extra 中读取指定 key 的字符串值 +func (a *Account) getExtraString(key string) string { + if a.Extra == nil { + return "" + } + if v, ok := a.Extra[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// getExtraInt 从 Extra 中读取指定 key 的 int 值 +func (a *Account) getExtraInt(key string) int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra[key]; ok { + return int(parseExtraFloat64(v)) + } + return 0 +} + +// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed" +func (a *Account) GetQuotaDailyResetMode() string { + if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" { + return "fixed" + } + return "rolling" +} + +// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0 +func (a *Account) GetQuotaDailyResetHour() int { + return a.getExtraInt("quota_daily_reset_hour") +} + +// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed" +func (a *Account) GetQuotaWeeklyResetMode() string { + if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" { + return "fixed" + } + return "rolling" +} + +// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一) +func (a *Account) GetQuotaWeeklyResetDay() int { + if a.Extra == nil { + return 1 + } + if _, ok := a.Extra["quota_weekly_reset_day"]; !ok { + return 1 + } + return a.getExtraInt("quota_weekly_reset_day") +} + +// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0 +func (a *Account) GetQuotaWeeklyResetHour() int { + return a.getExtraInt("quota_weekly_reset_hour") +} + +// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC" +func (a *Account) GetQuotaResetTimezone() string { + if tz := a.getExtraString("quota_reset_timezone"); tz != "" { + return tz + } + return "UTC" +} + +// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点 +func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time { + t := after.In(tz) + today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + if !after.Before(today) { + return today.AddDate(0, 0, 1) + } + return today +} + +// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点 +func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time { + t := now.In(tz) + today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + if now.Before(today) { + return today.AddDate(0, 0, -1) + } + return today +} + +// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点 +// day: 0=Sunday, 1=Monday, ..., 6=Saturday +func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time { + t := after.In(tz) + todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + currentDay := int(todayReset.Weekday()) + + daysForward := (day - currentDay + 7) % 7 + if daysForward == 0 && !after.Before(todayReset) { + daysForward = 7 + } + return todayReset.AddDate(0, 0, daysForward) +} + +// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点 +func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time { + t := now.In(tz) + todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz) + currentDay := int(todayReset.Weekday()) + + daysBack := (currentDay - day + 7) % 7 + if daysBack == 0 && now.Before(todayReset) { + daysBack = 7 + } + return todayReset.AddDate(0, 0, -daysBack) +} + +// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期 +func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool { + if periodStart.IsZero() { + return true + } + tz, err := time.LoadLocation(a.GetQuotaResetTimezone()) + if err != nil { + tz = time.UTC + } + lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now()) + return periodStart.Before(lastReset) +} + +// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期 +func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool { + if periodStart.IsZero() { + return true + } + tz, err := time.LoadLocation(a.GetQuotaResetTimezone()) + if err != nil { + tz = time.UTC + } + lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now()) + return periodStart.Before(lastReset) +} + +// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at +// 在保存账号配置时调用 +func ComputeQuotaResetAt(extra map[string]any) { + now := time.Now() + tzName, _ := extra["quota_reset_timezone"].(string) + if tzName == "" { + tzName = "UTC" + } + tz, err := time.LoadLocation(tzName) + if err != nil { + tz = time.UTC + } + + // 日配额固定重置时间 + if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" { + hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"])) + if hour < 0 || hour > 23 { + hour = 0 + } + resetAt := nextFixedDailyReset(hour, tz, now) + extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339) + } else { + delete(extra, "quota_daily_reset_at") + } + + // 周配额固定重置时间 + if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" { + day := 1 // 默认周一 + if d, ok := extra["quota_weekly_reset_day"]; ok { + day = int(parseExtraFloat64(d)) + } + if day < 0 || day > 6 { + day = 1 + } + hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"])) + if hour < 0 || hour > 23 { + hour = 0 + } + resetAt := nextFixedWeeklyReset(day, hour, tz, now) + extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339) + } else { + delete(extra, "quota_weekly_reset_at") + } +} + +// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性 +func ValidateQuotaResetConfig(extra map[string]any) error { + if extra == nil { + return nil + } + // 校验时区 + if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" { + if _, err := time.LoadLocation(tz); err != nil { + return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name") + } + } + // 日配额重置模式 + if mode, ok := extra["quota_daily_reset_mode"].(string); ok { + if mode != "rolling" && mode != "fixed" { + return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'") + } + } + // 日配额重置小时 + if v, ok := extra["quota_daily_reset_hour"]; ok { + hour := int(parseExtraFloat64(v)) + if hour < 0 || hour > 23 { + return errors.New("quota_daily_reset_hour must be between 0 and 23") + } + } + // 周配额重置模式 + if mode, ok := extra["quota_weekly_reset_mode"].(string); ok { + if mode != "rolling" && mode != "fixed" { + return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'") + } + } + // 周配额重置星期几 + if v, ok := extra["quota_weekly_reset_day"]; ok { + day := int(parseExtraFloat64(v)) + if day < 0 || day > 6 { + return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)") + } + } + // 周配额重置小时 + if v, ok := extra["quota_weekly_reset_hour"]; ok { + hour := int(parseExtraFloat64(v)) + if hour < 0 || hour > 23 { + return errors.New("quota_weekly_reset_hour must be between 0 and 23") + } + } + return nil +} + // HasAnyQuotaLimit 检查是否配置了任一维度的配额限制 func (a *Account) HasAnyQuotaLimit() bool { return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0 @@ -1291,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool { // 日额度(周期过期视为未超限,下次 increment 会重置) if limit := a.GetQuotaDailyLimit(); limit > 0 { start := a.getExtraTime("quota_daily_start") - if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit { + var expired bool + if a.GetQuotaDailyResetMode() == "fixed" { + expired = a.isFixedDailyPeriodExpired(start) + } else { + expired = isPeriodExpired(start, 24*time.Hour) + } + if !expired && a.GetQuotaDailyUsed() >= limit { return true } } // 周额度 if limit := a.GetQuotaWeeklyLimit(); limit > 0 { start := a.getExtraTime("quota_weekly_start") - if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit { + var expired bool + if a.GetQuotaWeeklyResetMode() == "fixed" { + expired = a.isFixedWeeklyPeriodExpired(start) + } else { + expired = isPeriodExpired(start, 7*24*time.Hour) + } + if !expired && a.GetQuotaWeeklyUsed() >= limit { return true } } diff --git a/backend/internal/service/account_quota_reset_test.go b/backend/internal/service/account_quota_reset_test.go new file mode 100644 index 00000000..45a4bad6 --- /dev/null +++ b/backend/internal/service/account_quota_reset_test.go @@ -0,0 +1,516 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// nextFixedDailyReset +// --------------------------------------------------------------------------- + +func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) { + tz := time.UTC + // 2026-03-14 06:00 UTC, reset hour = 9 + after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz) + got := nextFixedDailyReset(9, tz, after) + want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedDailyReset_AtResetHour(t *testing.T) { + tz := time.UTC + // Exactly at reset hour → should return tomorrow + after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + got := nextFixedDailyReset(9, tz, after) + want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedDailyReset_AfterResetHour(t *testing.T) { + tz := time.UTC + // After reset hour → should return tomorrow + after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz) + got := nextFixedDailyReset(9, tz, after) + want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedDailyReset_MidnightReset(t *testing.T) { + tz := time.UTC + // Reset at hour 0 (midnight), currently 23:59 + after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz) + got := nextFixedDailyReset(0, tz, after) + want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) { + tz, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + + // 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST) + after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC) + got := nextFixedDailyReset(9, tz, after) + // Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC + want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +// --------------------------------------------------------------------------- +// lastFixedDailyReset +// --------------------------------------------------------------------------- + +func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) { + tz := time.UTC + now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz) + got := lastFixedDailyReset(9, tz, now) + // Before today's 9:00 → yesterday 9:00 + want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestLastFixedDailyReset_AtResetHour(t *testing.T) { + tz := time.UTC + now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + got := lastFixedDailyReset(9, tz, now) + // At exactly 9:00 → today 9:00 + want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestLastFixedDailyReset_AfterResetHour(t *testing.T) { + tz := time.UTC + now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz) + got := lastFixedDailyReset(9, tz, now) + // After 9:00 → today 9:00 + want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +// --------------------------------------------------------------------------- +// nextFixedWeeklyReset +// --------------------------------------------------------------------------- + +func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) { + tz := time.UTC + // 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9 + after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Next Monday = 2026-03-16 + want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00 + after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Today at 9:00 + want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00 + after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Next Monday at 9:00 + want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00 + after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Next Monday at 9:00 + want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) { + tz := time.UTC + // 2026-03-18 is Wednesday (day=3), target = Monday (day=1) + after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz) + got := nextFixedWeeklyReset(1, 9, tz, after) + // Next Monday = 2026-03-23 + want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestNextFixedWeeklyReset_Sunday(t *testing.T) { + tz := time.UTC + // 2026-03-14 is Saturday (day=6), target = Sunday (day=0) + after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz) + got := nextFixedWeeklyReset(0, 0, tz, after) + // Next Sunday = 2026-03-15 + want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +// --------------------------------------------------------------------------- +// lastFixedWeeklyReset +// --------------------------------------------------------------------------- + +func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00 + now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz) + got := lastFixedWeeklyReset(1, 9, tz, now) + // Today at 9:00 + want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) { + tz := time.UTC + // 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00 + now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz) + got := lastFixedWeeklyReset(1, 9, tz, now) + // Last Monday at 9:00 = 2026-03-09 + want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) { + tz := time.UTC + // 2026-03-18 is Wednesday (day=3), target = Monday (day=1) + now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz) + got := lastFixedWeeklyReset(1, 9, tz, now) + // Last Monday = 2026-03-16 + want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz) + assert.Equal(t, want, got) +} + +// --------------------------------------------------------------------------- +// isFixedDailyPeriodExpired +// --------------------------------------------------------------------------- + +func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + assert.True(t, a.isFixedDailyPeriodExpired(time.Time{})) +} + +func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + // Period started after the most recent reset → not expired + // (This test uses a time very close to "now", which is after the last reset) + periodStart := time.Now().Add(-1 * time.Minute) + assert.False(t, a.isFixedDailyPeriodExpired(periodStart)) +} + +func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + // Period started 3 days ago → definitely expired + periodStart := time.Now().Add(-72 * time.Hour) + assert.True(t, a.isFixedDailyPeriodExpired(periodStart)) +} + +func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "Invalid/Timezone", + }} + // Invalid timezone falls back to UTC + periodStart := time.Now().Add(-72 * time.Hour) + assert.True(t, a.isFixedDailyPeriodExpired(periodStart)) +} + +// --------------------------------------------------------------------------- +// isFixedWeeklyPeriodExpired +// --------------------------------------------------------------------------- + +func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), + "quota_weekly_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{})) +} + +func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), + "quota_weekly_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + // Period started 1 minute ago → not expired + periodStart := time.Now().Add(-1 * time.Minute) + assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart)) +} + +func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) { + a := &Account{Extra: map[string]any{ + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), + "quota_weekly_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + }} + // Period started 10 days ago → definitely expired + periodStart := time.Now().Add(-240 * time.Hour) + assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart)) +} + +// --------------------------------------------------------------------------- +// ValidateQuotaResetConfig +// --------------------------------------------------------------------------- + +func TestValidateQuotaResetConfig_NilExtra(t *testing.T) { + assert.NoError(t, ValidateQuotaResetConfig(nil)) +} + +func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) { + assert.NoError(t, ValidateQuotaResetConfig(map[string]any{})) +} + +func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), + "quota_weekly_reset_hour": float64(0), + "quota_reset_timezone": "Asia/Shanghai", + } + assert.NoError(t, ValidateQuotaResetConfig(extra)) +} + +func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "rolling", + "quota_weekly_reset_mode": "rolling", + } + assert.NoError(t, ValidateQuotaResetConfig(extra)) +} + +func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) { + extra := map[string]any{ + "quota_reset_timezone": "Not/A/Timezone", + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_reset_timezone") +} + +func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "invalid", + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_daily_reset_mode") +} + +func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_hour": float64(24), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_daily_reset_hour") +} + +func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_hour": float64(-1), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_daily_reset_hour") +} + +func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_mode": "unknown", + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_weekly_reset_mode") +} + +func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_day": float64(7), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_weekly_reset_day") +} + +func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_day": float64(-1), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_weekly_reset_day") +} + +func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_hour": float64(25), + } + err := ValidateQuotaResetConfig(extra) + require.Error(t, err) + assert.Contains(t, err.Error(), "quota_weekly_reset_hour") +} + +func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) { + // All boundary values should be valid + extra := map[string]any{ + "quota_daily_reset_hour": float64(23), + "quota_weekly_reset_day": float64(0), // Sunday + "quota_weekly_reset_hour": float64(0), + "quota_reset_timezone": "UTC", + } + assert.NoError(t, ValidateQuotaResetConfig(extra)) + + extra2 := map[string]any{ + "quota_daily_reset_hour": float64(0), + "quota_weekly_reset_day": float64(6), // Saturday + "quota_weekly_reset_hour": float64(23), + } + assert.NoError(t, ValidateQuotaResetConfig(extra2)) +} + +// --------------------------------------------------------------------------- +// ComputeQuotaResetAt +// --------------------------------------------------------------------------- + +func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "rolling", + "quota_weekly_reset_mode": "rolling", + } + ComputeQuotaResetAt(extra) + _, hasDailyResetAt := extra["quota_daily_reset_at"] + _, hasWeeklyResetAt := extra["quota_weekly_reset_at"] + assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at") + assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at") +} + +func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "rolling", + "quota_weekly_reset_mode": "rolling", + "quota_daily_reset_at": "2026-03-14T09:00:00Z", + "quota_weekly_reset_at": "2026-03-16T09:00:00Z", + } + ComputeQuotaResetAt(extra) + _, hasDailyResetAt := extra["quota_daily_reset_at"] + _, hasWeeklyResetAt := extra["quota_weekly_reset_at"] + assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at") + assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at") +} + +func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "UTC", + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_daily_reset_at"].(string) + require.True(t, ok, "quota_daily_reset_at should be set") + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // Reset time should be in the future + assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future") + // Reset hour should be 9 UTC + assert.Equal(t, 9, resetAt.UTC().Hour()) +} + +func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) { + extra := map[string]any{ + "quota_weekly_reset_mode": "fixed", + "quota_weekly_reset_day": float64(1), // Monday + "quota_weekly_reset_hour": float64(0), + "quota_reset_timezone": "UTC", + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_weekly_reset_at"].(string) + require.True(t, ok, "quota_weekly_reset_at should be set") + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // Reset time should be in the future + assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future") + // Reset day should be Monday + assert.Equal(t, time.Monday, resetAt.UTC().Weekday()) +} + +func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) { + tz, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(9), + "quota_reset_timezone": "Asia/Shanghai", + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_daily_reset_at"].(string) + require.True(t, ok) + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // In Shanghai timezone, the hour should be 9 + assert.Equal(t, 9, resetAt.In(tz).Hour()) +} + +func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(12), + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_daily_reset_at"].(string) + require.True(t, ok) + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // Default timezone is UTC + assert.Equal(t, 12, resetAt.UTC().Hour()) +} + +func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) { + extra := map[string]any{ + "quota_daily_reset_mode": "fixed", + "quota_daily_reset_hour": float64(99), + "quota_reset_timezone": "UTC", + } + ComputeQuotaResetAt(extra) + resetAtStr, ok := extra["quota_daily_reset_at"].(string) + require.True(t, ok) + + resetAt, err := time.Parse(time.RFC3339, resetAtStr) + require.NoError(t, err) + // Invalid hour → clamped to 0 + assert.Equal(t, 0, resetAt.UTC().Hour()) +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 3dd931be..d41e890a 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "log" + "log/slog" "math/rand/v2" "net/http" "strings" @@ -100,6 +101,7 @@ type antigravityUsageCache struct { const ( apiCacheTTL = 3 * time.Minute apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟 + antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误) apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟 windowStatsCacheTTL = 1 * time.Minute openAIProbeCacheTTL = 10 * time.Minute @@ -108,11 +110,12 @@ const ( // UsageCache 封装账户使用量相关的缓存 type UsageCache struct { - apiCache sync.Map // accountID -> *apiUsageCache - windowStatsCache sync.Map // accountID -> *windowStatsCache - antigravityCache sync.Map // accountID -> *antigravityUsageCache - apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存 - openAIProbeCache sync.Map // accountID -> time.Time + apiCache sync.Map // accountID -> *apiUsageCache + windowStatsCache sync.Map // accountID -> *windowStatsCache + antigravityCache sync.Map // accountID -> *antigravityUsageCache + apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic) + antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存 + openAIProbeCache sync.Map // accountID -> time.Time } // NewUsageCache 创建 UsageCache 实例 @@ -149,6 +152,18 @@ type AntigravityModelQuota struct { ResetTime string `json:"reset_time"` // 重置时间 ISO8601 } +// AntigravityModelDetail Antigravity 单个模型的详细能力信息 +type AntigravityModelDetail struct { + DisplayName string `json:"display_name,omitempty"` + SupportsImages *bool `json:"supports_images,omitempty"` + SupportsThinking *bool `json:"supports_thinking,omitempty"` + ThinkingBudget *int `json:"thinking_budget,omitempty"` + Recommended *bool `json:"recommended,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"` +} + // UsageInfo 账号使用量信息 type UsageInfo struct { UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 @@ -164,6 +179,33 @@ type UsageInfo struct { // Antigravity 多模型配额 AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` + + // Antigravity 账号级信息 + SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN + SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称 + + // Antigravity 模型详细能力信息(与 antigravity_quota 同 key) + AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"` + + // Antigravity 废弃模型转发规则 (old_model_id -> new_model_id) + ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"` + + // Antigravity 账号是否被上游禁止 (HTTP 403) + IsForbidden bool `json:"is_forbidden,omitempty"` + ForbiddenReason string `json:"forbidden_reason,omitempty"` + ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden" + ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接 + + // 状态标记(从 ForbiddenType / HTTP 错误码推导) + NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation) + IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation) + NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401) + + // 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error + ErrorCode string `json:"error_code,omitempty"` + + // 获取 usage 时的错误信息(降级返回,而非 500) + Error string `json:"error,omitempty"` } // ClaudeUsageResponse Anthropic API返回的usage结构 @@ -648,34 +690,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account * return &UsageInfo{UpdatedAt: &now}, nil } - // 1. 检查缓存(10 分钟) + // 1. 检查缓存 if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { - if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { - // 重新计算 RemainingSeconds - usage := cache.usageInfo - if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { - usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) + if cache, ok := cached.(*antigravityUsageCache); ok { + ttl := antigravityCacheTTL(cache.usageInfo) + if time.Since(cache.timestamp) < ttl { + usage := cache.usageInfo + if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { + usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) + } + return usage, nil } - return usage, nil } } - // 2. 获取代理 URL - proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account) + // 2. singleflight 防止并发击穿 + flightKey := fmt.Sprintf("ag-usage:%d", account.ID) + result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) { + // 再次检查缓存(等待期间可能已被填充) + if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { + if cache, ok := cached.(*antigravityUsageCache); ok { + ttl := antigravityCacheTTL(cache.usageInfo) + if time.Since(cache.timestamp) < ttl { + usage := cache.usageInfo + // 重新计算 RemainingSeconds,避免返回过时的剩余秒数 + recalcAntigravityRemainingSeconds(usage) + return usage, nil + } + } + } - // 3. 调用 API 获取额度 - result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL) - if err != nil { - return nil, fmt.Errorf("fetch antigravity quota failed: %w", err) - } + // 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败 + fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer fetchCancel() - // 4. 缓存结果 - s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ - usageInfo: result.UsageInfo, - timestamp: time.Now(), + proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account) + fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL) + if err != nil { + degraded := buildAntigravityDegradedUsage(err) + enrichUsageWithAccountError(degraded, account) + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: degraded, + timestamp: time.Now(), + }) + return degraded, nil + } + + enrichUsageWithAccountError(fetchResult.UsageInfo, account) + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: fetchResult.UsageInfo, + timestamp: time.Now(), + }) + return fetchResult.UsageInfo, nil }) - return result.UsageInfo, nil + if flightErr != nil { + return nil, flightErr + } + usage, ok := result.(*UsageInfo) + if !ok || usage == nil { + now := time.Now() + return &UsageInfo{UpdatedAt: &now}, nil + } + return usage, nil +} + +// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds +// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数 +func recalcAntigravityRemainingSeconds(info *UsageInfo) { + if info == nil { + return + } + if info.FiveHour != nil && info.FiveHour.ResetsAt != nil { + remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds()) + if remaining < 0 { + remaining = 0 + } + info.FiveHour.RemainingSeconds = remaining + } +} + +// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL +// 403 forbidden 状态稳定,缓存与成功相同(3 分钟); +// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。 +func antigravityCacheTTL(info *UsageInfo) time.Duration { + if info == nil { + return antigravityErrorTTL + } + if info.IsForbidden { + return apiCacheTTL // 封号/验证状态不会很快变 + } + if info.ErrorCode != "" || info.Error != "" { + return antigravityErrorTTL + } + return apiCacheTTL +} + +// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo +func buildAntigravityDegradedUsage(err error) *UsageInfo { + now := time.Now() + errMsg := fmt.Sprintf("usage API error: %v", err) + slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err) + + info := &UsageInfo{ + UpdatedAt: &now, + Error: errMsg, + } + + // 从错误信息推断 error_code 和状态标记 + // 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..." + errStr := err.Error() + switch { + case strings.Contains(errStr, "HTTP 401") || + strings.Contains(errStr, "UNAUTHENTICATED") || + strings.Contains(errStr, "invalid_grant"): + info.ErrorCode = errorCodeUnauthenticated + info.NeedsReauth = true + case strings.Contains(errStr, "HTTP 429"): + info.ErrorCode = errorCodeRateLimited + default: + info.ErrorCode = errorCodeNetworkError + } + + return info +} + +// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo +// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error, +// +// 需要在正常 usage 数据上附加 forbidden/validation 信息。 +// +// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401, +// +// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。 +func enrichUsageWithAccountError(info *UsageInfo, account *Account) { + if info == nil || account == nil || account.Status != StatusError { + return + } + msg := strings.ToLower(account.ErrorMessage) + if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") && + !strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") { + return + } + fbType := classifyForbiddenType(account.ErrorMessage) + info.IsForbidden = true + info.ForbiddenType = fbType + info.ForbiddenReason = account.ErrorMessage + info.NeedsVerify = fbType == forbiddenTypeValidation + info.IsBanned = fbType == forbiddenTypeViolation + info.ValidationURL = extractValidationURL(account.ErrorMessage) + info.ErrorCode = errorCodeForbidden + info.NeedsReauth = false } // addWindowStats 为 usage 数据添加窗口期统计 diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 7782f948..0d7ffffa 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) { } } -func TestMatchWildcardMapping(t *testing.T) { +func TestMatchWildcardMappingResult(t *testing.T) { tests := []struct { name string mapping map[string]string requestedModel string expected string + matched bool }{ // 精确匹配优先于通配符 { @@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5-exact", + matched: true, }, // 最长通配符优先 @@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-series", + matched: true, }, // 单个通配符 @@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-opus-4-5", expected: "claude-mapped", + matched: true, }, // 无匹配返回原始模型 @@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "gemini-3-flash", expected: "gemini-3-flash", + matched: false, }, // 空映射返回原始模型 @@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) { mapping: map[string]string{}, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5", + matched: false, }, // Gemini 模型映射 @@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "gemini-3-flash-preview", expected: "gemini-3-pro-high", + matched: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := matchWildcardMapping(tt.mapping, tt.requestedModel) - if result != tt.expected { - t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected) + result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel) + if result != tt.expected || matched != tt.matched { + t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched) } }) } @@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) { } } +func TestAccountResolveMappedModel(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + requestedModel string + expectedModel string + expectedMatch bool + }{ + { + name: "no mapping reports unmatched", + credentials: nil, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: false, + }, + { + name: "exact passthrough mapping still counts as matched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: true, + }, + { + name: "wildcard passthrough mapping still counts as matched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-*": "gpt-5.4", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: true, + }, + { + name: "missing mapping reports unmatched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.2": "gpt-5.2", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Credentials: tt.credentials, + } + mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) + if mappedModel != tt.expectedModel || matched != tt.expectedMatch { + t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch) + } + }) + } +} + func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) { account := &Account{ Platform: PlatformAntigravity, diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 10d67518..86824b6f 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1462,6 +1462,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou Status: StatusActive, Schedulable: true, } + // 预计算固定时间重置的下次重置时间 + if account.Extra != nil { + if err := ValidateQuotaResetConfig(account.Extra); err != nil { + return nil, err + } + ComputeQuotaResetAt(account.Extra) + } if input.ExpiresAt != nil && *input.ExpiresAt > 0 { expiresAt := time.Unix(*input.ExpiresAt, 0) account.ExpiresAt = &expiresAt @@ -1535,6 +1542,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } } account.Extra = input.Extra + // 校验并预计算固定时间重置的下次重置时间 + if err := ValidateQuotaResetConfig(account.Extra); err != nil { + return nil, err + } + ComputeQuotaResetAt(account.Extra) } if input.ProxyID != nil { // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图) diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index e950ec1d..f8990b1a 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -2,12 +2,29 @@ package service import ( "context" + "encoding/json" + "errors" "fmt" + "log/slog" + "regexp" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) +const ( + forbiddenTypeValidation = "validation" + forbiddenTypeViolation = "violation" + forbiddenTypeForbidden = "forbidden" + + // 机器可读的错误码 + errorCodeForbidden = "forbidden" + errorCodeUnauthenticated = "unauthenticated" + errorCodeRateLimited = "rate_limited" + errorCodeNetworkError = "network_error" +) + // AntigravityQuotaFetcher 从 Antigravity API 获取额度 type AntigravityQuotaFetcher struct { proxyRepo ProxyRepository @@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou // 调用 API 获取配额 modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) if err != nil { + // 403 Forbidden: 不报错,返回 is_forbidden 标记 + var forbiddenErr *antigravity.ForbiddenError + if errors.As(err, &forbiddenErr) { + now := time.Now() + fbType := classifyForbiddenType(forbiddenErr.Body) + return &QuotaResult{ + UsageInfo: &UsageInfo{ + UpdatedAt: &now, + IsForbidden: true, + ForbiddenReason: forbiddenErr.Body, + ForbiddenType: fbType, + ValidationURL: extractValidationURL(forbiddenErr.Body), + NeedsVerify: fbType == forbiddenTypeValidation, + IsBanned: fbType == forbiddenTypeViolation, + ErrorCode: errorCodeForbidden, + }, + }, nil + } return nil, err } + // 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程) + tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken) + // 转换为 UsageInfo - usageInfo := f.buildUsageInfo(modelsResp) + usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized) return &QuotaResult{ UsageInfo: usageInfo, @@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou }, nil } -// buildUsageInfo 将 API 响应转换为 UsageInfo -func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo { - now := time.Now() - info := &UsageInfo{ - UpdatedAt: &now, - AntigravityQuota: make(map[string]*AntigravityModelQuota), +// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串 +func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) { + loadResp, _, err := client.LoadCodeAssist(ctx, accessToken) + if err != nil { + slog.Warn("failed to fetch subscription tier", "error", err) + return "", "" + } + if loadResp == nil { + return "", "" } - // 遍历所有模型,填充 AntigravityQuota + raw = loadResp.GetTier() // 已有方法:paidTier > currentTier + normalized = normalizeTier(raw) + return raw, normalized +} + +// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN +func normalizeTier(raw string) string { + if raw == "" { + return "" + } + lower := strings.ToLower(raw) + switch { + case strings.Contains(lower, "ultra"): + return "ULTRA" + case strings.Contains(lower, "pro"): + return "PRO" + case strings.Contains(lower, "free"): + return "FREE" + default: + return "UNKNOWN" + } +} + +// buildUsageInfo 将 API 响应转换为 UsageInfo +func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo { + now := time.Now() + info := &UsageInfo{ + UpdatedAt: &now, + AntigravityQuota: make(map[string]*AntigravityModelQuota), + AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail), + SubscriptionTier: tierNormalized, + SubscriptionTierRaw: tierRaw, + } + + // 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails for modelName, modelInfo := range modelsResp.Models { if modelInfo.QuotaInfo == nil { continue @@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv Utilization: utilization, ResetTime: modelInfo.QuotaInfo.ResetTime, } + + // 填充模型详细能力信息 + detail := &AntigravityModelDetail{ + DisplayName: modelInfo.DisplayName, + SupportsImages: modelInfo.SupportsImages, + SupportsThinking: modelInfo.SupportsThinking, + ThinkingBudget: modelInfo.ThinkingBudget, + Recommended: modelInfo.Recommended, + MaxTokens: modelInfo.MaxTokens, + MaxOutputTokens: modelInfo.MaxOutputTokens, + SupportedMimeTypes: modelInfo.SupportedMimeTypes, + } + info.AntigravityQuotaDetails[modelName] = detail + } + + // 废弃模型转发规则 + if len(modelsResp.DeprecatedModelIDs) > 0 { + info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs)) + for oldID, deprecated := range modelsResp.DeprecatedModelIDs { + info.ModelForwardingRules[oldID] = deprecated.NewModelID + } } // 同时设置 FiveHour 用于兼容展示(取主要模型) @@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco } return proxy.URL() } + +// classifyForbiddenType 根据 403 响应体判断禁止类型 +func classifyForbiddenType(body string) string { + lower := strings.ToLower(body) + switch { + case strings.Contains(lower, "validation_required") || + strings.Contains(lower, "verify your account") || + strings.Contains(lower, "validation_url"): + return forbiddenTypeValidation + case strings.Contains(lower, "terms of service") || + strings.Contains(lower, "violation"): + return forbiddenTypeViolation + default: + return forbiddenTypeForbidden + } +} + +// urlPattern 用于从 403 响应体中提取 URL(降级方案) +var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`) + +// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接 +func extractValidationURL(body string) string { + // 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url + var parsed struct { + Error struct { + Details []struct { + Metadata map[string]string `json:"metadata"` + } `json:"details"` + } `json:"error"` + } + if json.Unmarshal([]byte(body), &parsed) == nil { + for _, detail := range parsed.Error.Details { + if u := detail.Metadata["validation_url"]; u != "" { + return u + } + if u := detail.Metadata["appeal_url"]; u != "" { + return u + } + } + } + + // 2. 降级:正则匹配 URL + lower := strings.ToLower(body) + if !strings.Contains(lower, "validation") && + !strings.Contains(lower, "verify") && + !strings.Contains(lower, "appeal") { + return "" + } + // 先解码常见转义再匹配 + normalized := strings.ReplaceAll(body, `\u0026`, "&") + if m := urlPattern.FindString(normalized); m != "" { + return m + } + return "" +} diff --git a/backend/internal/service/antigravity_quota_fetcher_test.go b/backend/internal/service/antigravity_quota_fetcher_test.go new file mode 100644 index 00000000..5ead8e60 --- /dev/null +++ b/backend/internal/service/antigravity_quota_fetcher_test.go @@ -0,0 +1,497 @@ +//go:build unit + +package service + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +// --------------------------------------------------------------------------- +// normalizeTier +// --------------------------------------------------------------------------- + +func TestNormalizeTier(t *testing.T) { + tests := []struct { + name string + raw string + expected string + }{ + {name: "empty string", raw: "", expected: ""}, + {name: "free-tier", raw: "free-tier", expected: "FREE"}, + {name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"}, + {name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"}, + {name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"}, + {name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"}, + {name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"}, + {name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"}, + {name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeTier(tt.raw) + require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw) + }) + } +} + +// --------------------------------------------------------------------------- +// buildUsageInfo +// --------------------------------------------------------------------------- + +func aqfBoolPtr(v bool) *bool { return &v } +func aqfIntPtr(v int) *int { return &v } + +func TestBuildUsageInfo_BasicModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.75, + ResetTime: "2026-03-08T12:00:00Z", + }, + DisplayName: "Claude Sonnet 4", + SupportsImages: aqfBoolPtr(true), + SupportsThinking: aqfBoolPtr(false), + ThinkingBudget: aqfIntPtr(0), + Recommended: aqfBoolPtr(true), + MaxTokens: aqfIntPtr(200000), + MaxOutputTokens: aqfIntPtr(16384), + SupportedMimeTypes: map[string]bool{ + "image/png": true, + "image/jpeg": true, + }, + }, + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + ResetTime: "2026-03-08T15:00:00Z", + }, + DisplayName: "Gemini 2.5 Pro", + MaxTokens: aqfIntPtr(1000000), + MaxOutputTokens: aqfIntPtr(65536), + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO") + + // 基本字段 + require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set") + require.Equal(t, "PRO", info.SubscriptionTier) + require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw) + + // AntigravityQuota + require.Len(t, info.AntigravityQuota, 2) + + sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, sonnetQuota) + require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25 + require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime) + + geminiQuota := info.AntigravityQuota["gemini-2.5-pro"] + require.NotNil(t, geminiQuota) + require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50 + require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime) + + // AntigravityQuotaDetails + require.Len(t, info.AntigravityQuotaDetails, 2) + + sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"] + require.NotNil(t, sonnetDetail) + require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName) + require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages) + require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking) + require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget) + require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended) + require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens) + require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens) + require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes) + + geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"] + require.NotNil(t, geminiDetail) + require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName) + require.Nil(t, geminiDetail.SupportsImages) + require.Nil(t, geminiDetail.SupportsThinking) + require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens) + require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens) +} + +func TestBuildUsageInfo_DeprecatedModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 1.0, + }, + }, + }, + DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{ + "claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"}, + "claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"}, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.Len(t, info.ModelForwardingRules, 2) + require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"]) + require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"]) +} + +func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "some-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9}, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models") +} + +func TestBuildUsageInfo_EmptyModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{}, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info) + require.NotNil(t, info.AntigravityQuota) + require.Empty(t, info.AntigravityQuota) + require.NotNil(t, info.AntigravityQuotaDetails) + require.Empty(t, info.AntigravityQuotaDetails) + require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists") +} + +func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "model-without-quota": { + DisplayName: "No Quota Model", + // QuotaInfo is nil + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info) + require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped") + require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too") +} + +func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"] + // When the first priority model exists, it should be used for FiveHour + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.40, + ResetTime: "2026-03-08T18:00:00Z", + }, + }, + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.80, + ResetTime: "2026-03-08T12:00:00Z", + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists") + // claude-sonnet-4-20250514 is first in priority list, so it should be used + expectedUtilization := (1.0 - 0.80) * 100 // 20 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) + require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime") +} + +func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514 + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.60, + ResetTime: "2026-03-08T14:00:00Z", + }, + }, + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.30, + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info.FiveHour) + expectedUtilization := (1.0 - 0.60) * 100 // 40 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) +} + +func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // Only gemini-2.5-pro exists (third in priority list) + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.30, + }, + }, + "other-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.90, + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info.FiveHour) + expectedUtilization := (1.0 - 0.30) * 100 // 70 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) +} + +func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // None of the priority models exist + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "some-other-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists") +} + +func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + ResetTime: "", // empty reset time + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info.FiveHour) + require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty") + require.Equal(t, 0, info.FiveHour.RemainingSeconds) +} + +func TestBuildUsageInfo_FullUtilization(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.0, // fully used + ResetTime: "2026-03-08T12:00:00Z", + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + quota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, quota) + require.Equal(t, 100, quota.Utilization) +} + +func TestBuildUsageInfo_ZeroUtilization(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 1.0, // fully available + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + quota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, quota) + require.Equal(t, 0, quota.Utilization) +} + +func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) { + // 模拟 FetchQuota 遇到 403 时的行为: + // FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true + forbiddenErr := &antigravity.ForbiddenError{ + StatusCode: 403, + Body: "Access denied", + } + + // 验证 ForbiddenError 满足 errors.As + var target *antigravity.ForbiddenError + require.True(t, errors.As(forbiddenErr, &target)) + require.Equal(t, 403, target.StatusCode) + require.Equal(t, "Access denied", target.Body) + require.Contains(t, forbiddenErr.Error(), "403") +} + +// --------------------------------------------------------------------------- +// classifyForbiddenType +// --------------------------------------------------------------------------- + +func TestClassifyForbiddenType(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "VALIDATION_REQUIRED keyword", + body: `{"error":{"message":"VALIDATION_REQUIRED"}}`, + expected: "validation", + }, + { + name: "verify your account", + body: `Please verify your account to continue`, + expected: "validation", + }, + { + name: "contains validation_url field", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`, + expected: "validation", + }, + { + name: "terms of service violation", + body: `Your account has been suspended for Terms of Service violation`, + expected: "violation", + }, + { + name: "violation keyword", + body: `Account suspended due to policy violation`, + expected: "violation", + }, + { + name: "generic 403", + body: `Access denied`, + expected: "forbidden", + }, + { + name: "empty body", + body: "", + expected: "forbidden", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyForbiddenType(tt.body) + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// extractValidationURL +// --------------------------------------------------------------------------- + +func TestExtractValidationURL(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "structured validation_url", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`, + expected: "https://accounts.google.com/verify?token=abc", + }, + { + name: "structured appeal_url", + body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`, + expected: "https://support.google.com/appeal/123", + }, + { + name: "validation_url takes priority over appeal_url", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`, + expected: "https://v.com", + }, + { + name: "fallback regex with verify keyword", + body: `Please verify your account at https://accounts.google.com/verify`, + expected: "https://accounts.google.com/verify", + }, + { + name: "no URL in generic forbidden", + body: `Access denied`, + expected: "", + }, + { + name: "empty body", + body: "", + expected: "", + }, + { + name: "URL present but no validation keywords", + body: `Error at https://example.com/something`, + expected: "", + }, + { + name: "unicode escaped ampersand", + body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`, + expected: "https://accounts.google.com/verify?a=1&b=2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractValidationURL(tt.body) + require.Equal(t, tt.expected, got) + }) + } +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 28607e9f..6e524fb9 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -1087,6 +1087,12 @@ type TokenPair struct { ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) } +// TokenPairWithUser extends TokenPair with user role for backend mode checks +type TokenPairWithUser struct { + TokenPair + UserRole string +} + // GenerateTokenPair 生成Access Token和Refresh Token对 // familyID: 可选的Token家族ID,用于Token轮转时保持家族关系 func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) { @@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami // RefreshTokenPair 使用Refresh Token刷新Token对 // 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效 -func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) { +func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, ErrRefreshTokenInvalid @@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) } // 生成新的Token对,保持同一个家族ID - return s.GenerateTokenPair(ctx, user, data.FamilyID) + pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID) + if err != nil { + return nil, err + } + return &TokenPairWithUser{ + TokenPair: *pair, + UserRole: user.Role, + }, nil } // RevokeRefreshToken 撤销单个Refresh Token diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index ad64b467..c605f67a 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -29,12 +29,11 @@ const ( // Account type constants const ( - AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 - AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) - AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock) - AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock) + AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 + AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) ) // Redeem type constants @@ -221,6 +220,9 @@ const ( // SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403) SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling" + + // SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录 + SettingKeyBackendModeEnabled = "backend_mode_enabled" ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go index dd9850bd..297a954c 100644 --- a/backend/internal/service/error_policy_test.go +++ b/backend/internal/service/error_policy_test.go @@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) { expected: ErrorPolicyTempUnscheduled, }, { - name: "temp_unschedulable_401_second_hit_upgrades_to_none", + // Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制), + // second hit 仍然返回 TempUnscheduled。 + name: "temp_unschedulable_401_second_hit_antigravity_stays_temp", account: &Account{ ID: 15, Type: AccountTypeOAuth, @@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) { }, statusCode: 401, body: []byte(`unauthorized`), - expected: ErrorPolicyNone, + expected: ErrorPolicyTempUnscheduled, }, { name: "temp_unschedulable_body_miss_returns_none", diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 1756b1f4..55b11ec2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -2173,10 +2173,10 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts [] return context.WithValue(ctx, windowCostPrefetchContextKey, costs) } -// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内 -// 仅适用于配置了 quota_limit 的 apikey 类型账号 +// isAccountSchedulableForQuota 检查账号是否在配额限制内 +// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号 func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool { - if account.Type != AccountTypeAPIKey { + if !account.IsAPIKeyOrBedrock() { return true } return !account.IsQuotaExceeded() @@ -3532,9 +3532,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( } return apiKey, "apikey", nil case AccountTypeBedrock: - return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token - case AccountTypeBedrockAPIKey: - return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理 + return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理 default: return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } @@ -5186,7 +5184,7 @@ func (s *GatewayService) forwardBedrock( if account.IsBedrockAPIKey() { bedrockAPIKey = account.GetCredential("api_key") if bedrockAPIKey == "" { - return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials") + return nil, fmt.Errorf("api_key not found in bedrock credentials") } } else { signer, err = NewBedrockSignerFromAccount(account) @@ -5375,8 +5373,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors( Message: extractUpstreamErrorMessage(respBody), }) return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ResponseBody: respBody, + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } return s.handleRetryExhaustedError(ctx, resp, c, account) @@ -5398,8 +5397,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors( Message: extractUpstreamErrorMessage(respBody), }) return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ResponseBody: respBody, + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } @@ -5808,9 +5808,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri return betaPolicyResult{} } isOAuth := account.IsOAuth() + isBedrock := account.IsBedrock() var result betaPolicyResult for _, rule := range settings.Rules { - if !betaPolicyScopeMatches(rule.Scope, isOAuth) { + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } switch rule.Action { @@ -5870,14 +5871,16 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont } // betaPolicyScopeMatches checks whether a rule's scope matches the current account type. -func betaPolicyScopeMatches(scope string, isOAuth bool) bool { +func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { switch scope { case BetaPolicyScopeAll: return true case BetaPolicyScopeOAuth: return isOAuth case BetaPolicyScopeAPIKey: - return !isOAuth + return !isOAuth && !isBedrock + case BetaPolicyScopeBedrock: + return isBedrock default: return true // unknown scope → match all (fail-open) } @@ -5959,12 +5962,13 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke return nil } isOAuth := account.IsOAuth() + isBedrock := account.IsBedrock() tokenSet := buildBetaTokenSet(tokens) for _, rule := range settings.Rules { if rule.Action != BetaPolicyActionBlock { continue } - if !betaPolicyScopeMatches(rule.Scope, isOAuth) { + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } if _, present := tokenSet[rule.BetaToken]; present { @@ -7199,7 +7203,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) - if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { + if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { accountCost := cost.TotalCost * p.AccountRateMultiplier if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) @@ -7287,7 +7291,7 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { cmd.APIKeyRateLimitCost = p.Cost.ActualCost } - if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { + if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier } diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 8fffce1b..29f2b672 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -339,8 +339,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } typ, _ := m["type"].(string) - // 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'" - fixIDPrefix := func(id string) string { + // 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id; + // 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。 + fixCallIDPrefix := func(id string) string { if id == "" || strings.HasPrefix(id, "fc") { return id } @@ -358,8 +359,8 @@ func filterCodexInput(input []any, preserveReferences bool) []any { for key, value := range m { newItem[key] = value } - if id, ok := newItem["id"].(string); ok && id != "" { - newItem["id"] = fixIDPrefix(id) + if id, ok := newItem["id"].(string); ok && strings.HasPrefix(id, "call_") { + newItem["id"] = fixCallIDPrefix(id) } filtered = append(filtered, newItem) continue @@ -390,7 +391,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } if callID != "" { - fixedCallID := fixIDPrefix(callID) + fixedCallID := fixCallIDPrefix(callID) if fixedCallID != callID { ensureCopy() newItem["call_id"] = fixedCallID @@ -404,14 +405,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any { if !isCodexToolCallItemType(typ) { delete(newItem, "call_id") } - } else { - if id, ok := newItem["id"].(string); ok && id != "" { - fixedID := fixIDPrefix(id) - if fixedID != id { - ensureCopy() - newItem["id"] = fixedID - } - } } filtered = append(filtered, newItem) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index df012d7c..ae6f8555 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -33,12 +33,63 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { first, ok := input[0].(map[string]any) require.True(t, ok) require.Equal(t, "item_reference", first["type"]) - require.Equal(t, "fc_ref1", first["id"]) + require.Equal(t, "ref1", first["id"]) // 校验 input[1] 为 map,确保后续字段断言安全。 second, ok := input[1].(map[string]any) require.True(t, ok) - require.Equal(t, "fc_o1", second["id"]) + require.Equal(t, "o1", second["id"]) + require.Equal(t, "fc1", second["call_id"]) +} + +func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"}, + map[string]any{"type": "item_reference", "id": "rs_123"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "msg_0", first["id"]) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "rs_123", second["id"]) +} + +func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "item_reference", "id": "call_1"}, + map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "fc1", first["id"]) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "fc1", second["call_id"]) } func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index f893eeb9..9529f6be 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } // 3. Model mapping - mappedModel := account.GetMappedModel(originalModel) - if mappedModel == originalModel && defaultMappedModel != "" { - mappedModel = defaultMappedModel - } + mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) responsesReq.Model = mappedModel logger.L().Debug("openai chat_completions: model mapping applied", diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index e4a3d9c0..1e40ec6f 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 3. Model mapping - mappedModel := account.GetMappedModel(originalModel) - // 分组级降级:账号未映射时使用分组默认映射模型 - if mappedModel == originalModel && defaultMappedModel != "" { - mappedModel = defaultMappedModel - } + mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) responsesReq.Model = mappedModel logger.L().Debug("openai messages: model mapping applied", diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go new file mode 100644 index 00000000..9bf3fba3 --- /dev/null +++ b/backend/internal/service/openai_model_mapping.go @@ -0,0 +1,19 @@ +package service + +// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible +// forwarding. Group-level default mapping only applies when the account itself +// did not match any explicit model_mapping rule. +func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { + if account == nil { + if defaultMappedModel != "" { + return defaultMappedModel + } + return requestedModel + } + + mappedModel, matched := account.ResolveMappedModel(requestedModel) + if !matched && defaultMappedModel != "" { + return defaultMappedModel + } + return mappedModel +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go new file mode 100644 index 00000000..7af3ecae --- /dev/null +++ b/backend/internal/service/openai_model_mapping_test.go @@ -0,0 +1,70 @@ +package service + +import "testing" + +func TestResolveOpenAIForwardModel(t *testing.T) { + tests := []struct { + name string + account *Account + requestedModel string + defaultMappedModel string + expectedModel string + }{ + { + name: "falls back to group default when account has no mapping", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-4o-mini", + }, + { + name: "preserves exact passthrough mapping instead of group default", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + { + name: "preserves wildcard passthrough mapping instead of group default", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-*": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + { + name: "uses account remap when explicit target differs", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel { + t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel) + } + }) + } +} diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go index 7514cc80..93815887 100644 --- a/backend/internal/service/ops_settings.go +++ b/backend/internal/service/ops_settings.go @@ -371,6 +371,8 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings { IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略 IgnoreContextCanceled: true, // Default to true - client disconnects are not errors IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue + DisplayOpenAITokenStats: false, + DisplayAlertEvents: true, AutoRefreshEnabled: false, AutoRefreshIntervalSec: 30, } @@ -438,7 +440,7 @@ func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSe return nil, err } - cfg := &OpsAdvancedSettings{} + cfg := defaultOpsAdvancedSettings() if err := json.Unmarshal([]byte(raw), cfg); err != nil { return defaultCfg, nil } diff --git a/backend/internal/service/ops_settings_advanced_test.go b/backend/internal/service/ops_settings_advanced_test.go new file mode 100644 index 00000000..06cc545b --- /dev/null +++ b/backend/internal/service/ops_settings_advanced_test.go @@ -0,0 +1,97 @@ +package service + +import ( + "context" + "encoding/json" + "testing" +) + +func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{settingRepo: repo} + + cfg, err := svc.GetOpsAdvancedSettings(context.Background()) + if err != nil { + t.Fatalf("GetOpsAdvancedSettings() error = %v", err) + } + if cfg.DisplayOpenAITokenStats { + t.Fatalf("DisplayOpenAITokenStats = true, want false by default") + } + if !cfg.DisplayAlertEvents { + t.Fatalf("DisplayAlertEvents = false, want true by default") + } + if repo.setCalls != 1 { + t.Fatalf("expected defaults to be persisted once, got %d", repo.setCalls) + } +} + +func TestUpdateOpsAdvancedSettings_PersistsOpenAITokenStatsVisibility(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{settingRepo: repo} + + cfg := defaultOpsAdvancedSettings() + cfg.DisplayOpenAITokenStats = true + cfg.DisplayAlertEvents = false + + updated, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg) + if err != nil { + t.Fatalf("UpdateOpsAdvancedSettings() error = %v", err) + } + if !updated.DisplayOpenAITokenStats { + t.Fatalf("DisplayOpenAITokenStats = false, want true") + } + if updated.DisplayAlertEvents { + t.Fatalf("DisplayAlertEvents = true, want false") + } + + reloaded, err := svc.GetOpsAdvancedSettings(context.Background()) + if err != nil { + t.Fatalf("GetOpsAdvancedSettings() after update error = %v", err) + } + if !reloaded.DisplayOpenAITokenStats { + t.Fatalf("reloaded DisplayOpenAITokenStats = false, want true") + } + if reloaded.DisplayAlertEvents { + t.Fatalf("reloaded DisplayAlertEvents = true, want false") + } +} + +func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{settingRepo: repo} + + legacyCfg := map[string]any{ + "data_retention": map[string]any{ + "cleanup_enabled": false, + "cleanup_schedule": "0 2 * * *", + "error_log_retention_days": 30, + "minute_metrics_retention_days": 30, + "hourly_metrics_retention_days": 30, + }, + "aggregation": map[string]any{ + "aggregation_enabled": false, + }, + "ignore_count_tokens_errors": true, + "ignore_context_canceled": true, + "ignore_no_available_accounts": false, + "ignore_invalid_api_key_errors": false, + "auto_refresh_enabled": false, + "auto_refresh_interval_seconds": 30, + } + raw, err := json.Marshal(legacyCfg) + if err != nil { + t.Fatalf("marshal legacy config: %v", err) + } + repo.values[SettingKeyOpsAdvancedSettings] = string(raw) + + cfg, err := svc.GetOpsAdvancedSettings(context.Background()) + if err != nil { + t.Fatalf("GetOpsAdvancedSettings() error = %v", err) + } + if cfg.DisplayOpenAITokenStats { + t.Fatalf("DisplayOpenAITokenStats = true, want false default backfill") + } + if !cfg.DisplayAlertEvents { + t.Fatalf("DisplayAlertEvents = false, want true default backfill") + } +} diff --git a/backend/internal/service/ops_settings_models.go b/backend/internal/service/ops_settings_models.go index 8b5359e3..c8b9fcd1 100644 --- a/backend/internal/service/ops_settings_models.go +++ b/backend/internal/service/ops_settings_models.go @@ -98,6 +98,8 @@ type OpsAdvancedSettings struct { IgnoreContextCanceled bool `json:"ignore_context_canceled"` IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"` IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"` + DisplayOpenAITokenStats bool `json:"display_openai_token_stats"` + DisplayAlertEvents bool `json:"display_alert_events"` AutoRefreshEnabled bool `json:"auto_refresh_enabled"` AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"` } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 5df2d639..d410555d 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -149,8 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: - // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 - if account.Type == AccountTypeOAuth { + // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。 + // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。 + if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity { // 1. 失效缓存 if s.tokenCacheInvalidator != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { @@ -182,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } shouldDisable = true } else { - // 非 OAuth 账号(APIKey):保持原有 SetError 行为 + // 非 OAuth / Antigravity OAuth:保持 SetError 行为 msg := "Authentication failed (401): invalid or expired credentials" if upstreamMsg != "" { msg = "Authentication failed (401): " + upstreamMsg @@ -199,11 +200,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc s.handleAuthError(ctx, account, msg) shouldDisable = true case 403: - // 禁止访问:停止调度,记录错误 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg - } logger.LegacyPrintf( "service.ratelimit", "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s", @@ -215,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc upstreamMsg, truncateForLog(responseBody, 1024), ) - s.handleAuthError(ctx, account, msg) - shouldDisable = true + shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody) case 429: s.handle429(ctx, account, headers, responseBody) shouldDisable = false @@ -621,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } +// handle403 处理 403 Forbidden 错误 +// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用; +// 其他平台保持原有 SetError 行为。 +func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + if account.Platform == PlatformAntigravity { + return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody) + } + // 非 Antigravity 平台:保持原有行为 + msg := "Access forbidden (403): account may be suspended or lack permissions" + if upstreamMsg != "" { + msg = "Access forbidden (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true +} + +// handleAntigravity403 处理 Antigravity 平台的 403 错误 +// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复) +// violation(违规封号)→ 永久 SetError(需人工处理) +// generic(通用禁止)→ 永久 SetError +func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + fbType := classifyForbiddenType(string(responseBody)) + + switch fbType { + case forbiddenTypeValidation: + // VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复 + msg := "Validation required (403): account needs Google verification" + if upstreamMsg != "" { + msg = "Validation required (403): " + upstreamMsg + } + if validationURL := extractValidationURL(string(responseBody)); validationURL != "" { + msg += " | validation_url: " + validationURL + } + s.handleAuthError(ctx, account, msg) + return true + + case forbiddenTypeViolation: + // 违规封号: 永久禁用,需人工处理 + msg := "Account violation (403): terms of service violation" + if upstreamMsg != "" { + msg = "Account violation (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true + + default: + // 通用 403: 保持原有行为 + msg := "Access forbidden (403): account may be suspended or lack permissions" + if upstreamMsg != "" { + msg = "Access forbidden (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true + } +} + // handleCustomErrorCode 处理自定义错误码,停止账号调度 func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg @@ -1213,7 +1264,8 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac } // 401 首次命中可临时不可调度(给 token 刷新窗口); // 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。 - if statusCode == http.StatusUnauthorized { + // Antigravity 跳过:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制,无需升级逻辑。 + if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity { reason := account.TempUnschedulableReason // 缓存可能没有 reason,从 DB 回退读取 if reason == "" { diff --git a/backend/internal/service/ratelimit_service_401_db_fallback_test.go b/backend/internal/service/ratelimit_service_401_db_fallback_test.go index e1611425..d245b5d5 100644 --- a/backend/internal/service/ratelimit_service_401_db_fallback_test.go +++ b/backend/internal/service/ratelimit_service_401_db_fallback_test.go @@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { // Scenario: cache account has empty TempUnschedulableReason (cache miss), - // but DB account has a previous 401 record → should escalate to ErrorPolicyNone. - repo := &dbFallbackRepoStub{ - dbAccount: &Account{ - ID: 20, - TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, - }, - } - svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + // but DB account has a previous 401 record. + // Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error). + // Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules). + t.Run("gemini_escalates", func(t *testing.T) { + repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 20, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - account := &Account{ - ID: 20, - Type: AccountTypeOAuth, - Platform: PlatformAntigravity, - TempUnschedulableReason: "", // cache miss — reason is empty - Credentials: map[string]any{ - "temp_unschedulable_enabled": true, - "temp_unschedulable_rules": []any{ - map[string]any{ - "error_code": float64(401), - "keywords": []any{"unauthorized"}, - "duration_minutes": float64(10), + account := &Account{ + ID: 20, + Type: AccountTypeOAuth, + Platform: PlatformGemini, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, }, }, - }, - } + } - result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) - require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone") + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate") + }) + + t.Run("antigravity_stays_temp", func(t *testing.T) { + repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 20, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 20, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled") + }) } func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) { diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 7bced46f..4a6e5d6c 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc } func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { - tests := []struct { - name string - platform string - }{ - {name: "gemini", platform: PlatformGemini}, - {name: "antigravity", platform: PlatformAntigravity}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - invalidator := &tokenCacheInvalidatorRecorder{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - service.SetTokenCacheInvalidator(invalidator) - account := &Account{ - ID: 100, - Platform: tt.platform, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "temp_unschedulable_enabled": true, - "temp_unschedulable_rules": []any{ - map[string]any{ - "error_code": 401, - "keywords": []any{"unauthorized"}, - "duration_minutes": 30, - "description": "custom rule", - }, + t.Run("gemini", func(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 100, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": 401, + "keywords": []any{"unauthorized"}, + "duration_minutes": 30, + "description": "custom rule", }, }, - } + }, + } - shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) - require.True(t, shouldDisable) - require.Equal(t, 0, repo.setErrorCalls) - require.Equal(t, 1, repo.tempCalls) - require.Len(t, invalidator.accounts, 1) - }) - } + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) + require.Len(t, invalidator.accounts, 1) + }) + + t.Run("antigravity_401_uses_SetError", func(t *testing.T) { + // Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制, + // HandleUpstreamError 中走 SetError 路径。 + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 100, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Empty(t, invalidator.accounts) + }) } func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index b77867de..a8710b59 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -65,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second // minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context const minVersionDBTimeout = 5 * time.Second +// cachedBackendMode Backend Mode cache (in-process, 60s TTL) +type cachedBackendMode struct { + value bool + expiresAt int64 // unix nano +} + +var backendModeCache atomic.Value // *cachedBackendMode +var backendModeSF singleflight.Group + +const backendModeCacheTTL = 60 * time.Second +const backendModeErrorTTL = 5 * time.Second +const backendModeDBTimeout = 5 * time.Second + // DefaultSubscriptionGroupReader validates group references used by default subscriptions. type DefaultSubscriptionGroupReader interface { GetByID(ctx context.Context, id int64) (*Group, error) @@ -128,6 +141,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeySoraClientEnabled, SettingKeyCustomMenuItems, SettingKeyLinuxDoConnectEnabled, + SettingKeyBackendModeEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -172,6 +186,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], LinuxDoOAuthEnabled: linuxDoEnabled, + BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", }, nil } @@ -223,6 +238,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems json.RawMessage `json:"custom_menu_items"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + BackendModeEnabled bool `json:"backend_mode_enabled"` Version string `json:"version,omitempty"` }{ RegistrationEnabled: settings.RegistrationEnabled, @@ -247,6 +263,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + BackendModeEnabled: settings.BackendModeEnabled, Version: s.version, }, nil } @@ -461,6 +478,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // 分组隔离 updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling) + // Backend Mode + updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled) + err = s.settingRepo.SetMultiple(ctx, updates) if err == nil { // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 @@ -469,6 +489,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet value: settings.MinClaudeCodeVersion, expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), }) + backendModeSF.Forget("backend_mode") + backendModeCache.Store(&cachedBackendMode{ + value: settings.BackendModeEnabled, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) if s.onUpdate != nil { s.onUpdate() // Invalidate cache after settings update } @@ -525,6 +550,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { return value == "true" } +// IsBackendModeEnabled checks if backend mode is enabled +// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path +func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool { + if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value + } + } + result, _, _ := backendModeSF.Do("backend_mode", func() (any, error) { + if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value, nil + } + } + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), backendModeDBTimeout) + defer cancel() + value, err := s.settingRepo.GetValue(dbCtx, SettingKeyBackendModeEnabled) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + // Setting not yet created (fresh install) - default to disabled with full TTL + backendModeCache.Store(&cachedBackendMode{ + value: false, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + return false, nil + } + slog.Warn("failed to get backend_mode_enabled setting", "error", err) + backendModeCache.Store(&cachedBackendMode{ + value: false, + expiresAt: time.Now().Add(backendModeErrorTTL).UnixNano(), + }) + return false, nil + } + enabled := value == "true" + backendModeCache.Store(&cachedBackendMode{ + value: enabled, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + return enabled, nil + }) + if val, ok := result.(bool); ok { + return val + } + return false +} + // IsEmailVerifyEnabled 检查是否开启邮件验证 func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled) @@ -719,6 +790,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], + BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", } // 解析整数类型 @@ -1278,7 +1350,7 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true, } validScopes := map[string]bool{ - BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, + BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true, } for i, rule := range settings.Rules { diff --git a/backend/internal/service/setting_service_backend_mode_test.go b/backend/internal/service/setting_service_backend_mode_test.go new file mode 100644 index 00000000..39922ec8 --- /dev/null +++ b/backend/internal/service/setting_service_backend_mode_test.go @@ -0,0 +1,199 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type bmRepoStub struct { + getValueFn func(ctx context.Context, key string) (string, error) + calls int +} + +func (s *bmRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *bmRepoStub) GetValue(ctx context.Context, key string) (string, error) { + s.calls++ + if s.getValueFn == nil { + panic("unexpected GetValue call") + } + return s.getValueFn(ctx, key) +} + +func (s *bmRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *bmRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *bmRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *bmRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *bmRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +type bmUpdateRepoStub struct { + updates map[string]string + getValueFn func(ctx context.Context, key string) (string, error) +} + +func (s *bmUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *bmUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) { + if s.getValueFn == nil { + panic("unexpected GetValue call") + } + return s.getValueFn(ctx, key) +} + +func (s *bmUpdateRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *bmUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *bmUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for k, v := range settings { + s.updates[k] = v + } + return nil +} + +func (s *bmUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *bmUpdateRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func resetBackendModeTestCache(t *testing.T) { + t.Helper() + + backendModeCache.Store((*cachedBackendMode)(nil)) + t.Cleanup(func() { + backendModeCache.Store((*cachedBackendMode)(nil)) + }) +} + +func TestIsBackendModeEnabled_ReturnsTrue(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "true", nil + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.True(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestIsBackendModeEnabled_ReturnsFalse(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "false", nil + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.False(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestIsBackendModeEnabled_ReturnsFalseOnNotFound(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "", ErrSettingNotFound + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.False(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestIsBackendModeEnabled_ReturnsFalseOnDBError(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "", errors.New("db down") + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.False(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestIsBackendModeEnabled_CachesResult(t *testing.T) { + resetBackendModeTestCache(t) + + repo := &bmRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "true", nil + }, + } + svc := NewSettingService(repo, &config.Config{}) + + require.True(t, svc.IsBackendModeEnabled(context.Background())) + require.True(t, svc.IsBackendModeEnabled(context.Background())) + require.Equal(t, 1, repo.calls) +} + +func TestUpdateSettings_InvalidatesBackendModeCache(t *testing.T) { + resetBackendModeTestCache(t) + + backendModeCache.Store(&cachedBackendMode{ + value: true, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + + repo := &bmUpdateRepoStub{ + getValueFn: func(ctx context.Context, key string) (string, error) { + require.Equal(t, SettingKeyBackendModeEnabled, key) + return "true", nil + }, + } + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + BackendModeEnabled: false, + }) + require.NoError(t, err) + require.Equal(t, "false", repo.updates[SettingKeyBackendModeEnabled]) + require.False(t, svc.IsBackendModeEnabled(context.Background())) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 8734e28a..29eb0a36 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -69,6 +69,9 @@ type SystemSettings struct { // 分组隔离:允许未分组 Key 调度(默认 false → 403) AllowUngroupedKeyScheduling bool + + // Backend 模式:禁用用户注册和自助服务,仅管理员可登录 + BackendModeEnabled bool } type DefaultSubscriptionSetting struct { @@ -101,6 +104,7 @@ type PublicSettings struct { CustomMenuItems string // JSON array of custom menu items LinuxDoOAuthEnabled bool + BackendModeEnabled bool Version string } @@ -198,16 +202,17 @@ const ( BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token BetaPolicyActionBlock = "block" // 拦截,直接返回错误 - BetaPolicyScopeAll = "all" // 所有账号类型 - BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号 - BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号 + BetaPolicyScopeAll = "all" // 所有账号类型 + BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号 + BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号 + BetaPolicyScopeBedrock = "bedrock" // 仅 AWS Bedrock 账号 ) // BetaPolicyRule 单条 Beta 策略规则 type BetaPolicyRule struct { BetaToken string `json:"beta_token"` // beta token 值 Action string `json:"action"` // "pass" | "filter" | "block" - Scope string `json:"scope"` // "all" | "oauth" | "apikey" + Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock" ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效) } diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index b8d1691f..11699c79 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -841,6 +841,8 @@ export interface OpsAdvancedSettings { ignore_context_canceled: boolean ignore_no_available_accounts: boolean ignore_invalid_api_key_errors: boolean + display_openai_token_stats: boolean + display_alert_events: boolean auto_refresh_enabled: boolean auto_refresh_interval_seconds: number } diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 2b156ea1..b14390f3 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -40,6 +40,7 @@ export interface SystemSettings { purchase_subscription_enabled: boolean purchase_subscription_url: string sora_client_enabled: boolean + backend_mode_enabled: boolean custom_menu_items: CustomMenuItem[] // SMTP settings smtp_host: string @@ -106,6 +107,7 @@ export interface UpdateSettingsRequest { purchase_subscription_enabled?: boolean purchase_subscription_url?: string sora_client_enabled?: boolean + backend_mode_enabled?: boolean custom_menu_items?: CustomMenuItem[] smtp_host?: string smtp_port?: number @@ -316,7 +318,7 @@ export async function updateRectifierSettings( export interface BetaPolicyRule { beta_token: string action: 'pass' | 'filter' | 'block' - scope: 'all' | 'oauth' | 'apikey' + scope: 'all' | 'oauth' | 'apikey' | 'bedrock' error_message?: string } diff --git a/frontend/src/components/account/AccountCapacityCell.vue b/frontend/src/components/account/AccountCapacityCell.vue index b077264d..f8fe4b47 100644 --- a/frontend/src/components/account/AccountCapacityCell.vue +++ b/frontend/src/components/account/AccountCapacityCell.vue @@ -292,17 +292,19 @@ const rpmTooltip = computed(() => { } }) -// 是否显示各维度配额(仅 apikey 类型) +// 是否显示各维度配额(apikey / bedrock 类型) +const isQuotaEligible = computed(() => props.account.type === 'apikey' || props.account.type === 'bedrock') + const showDailyQuota = computed(() => { - return props.account.type === 'apikey' && (props.account.quota_daily_limit ?? 0) > 0 + return isQuotaEligible.value && (props.account.quota_daily_limit ?? 0) > 0 }) const showWeeklyQuota = computed(() => { - return props.account.type === 'apikey' && (props.account.quota_weekly_limit ?? 0) > 0 + return isQuotaEligible.value && (props.account.quota_weekly_limit ?? 0) > 0 }) const showTotalQuota = computed(() => { - return props.account.type === 'apikey' && (props.account.quota_limit ?? 0) > 0 + return isQuotaEligible.value && (props.account.quota_limit ?? 0) > 0 }) // 格式化费用显示 diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index e83eaead..fb145f98 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -36,6 +36,10 @@
+ +
+ {{ usageInfo.error }} +
+ +
+ + {{ forbiddenLabel }} + +
+ + {{ t('admin.accounts.openVerification') }} + + +
+
+ + +
+ + {{ t('admin.accounts.needsReauth') }} + +
+ + +
+ + {{ usageErrorLabel }} + +
+ -
+
@@ -816,6 +865,51 @@ const hasIneligibleTiers = computed(() => { return Array.isArray(ineligibleTiers) && ineligibleTiers.length > 0 }) +// Antigravity 403 forbidden 状态 +const isForbidden = computed(() => !!usageInfo.value?.is_forbidden) +const forbiddenType = computed(() => usageInfo.value?.forbidden_type || 'forbidden') +const validationURL = computed(() => usageInfo.value?.validation_url || '') + +// 需要重新授权(401) +const needsReauth = computed(() => !!usageInfo.value?.needs_reauth) + +// 降级错误标签(rate_limited / network_error) +const usageErrorLabel = computed(() => { + const code = usageInfo.value?.error_code + if (code === 'rate_limited') return t('admin.accounts.rateLimited') + return t('admin.accounts.usageError') +}) + +const forbiddenLabel = computed(() => { + switch (forbiddenType.value) { + case 'validation': + return t('admin.accounts.forbiddenValidation') + case 'violation': + return t('admin.accounts.forbiddenViolation') + default: + return t('admin.accounts.forbidden') + } +}) + +const forbiddenBadgeClass = computed(() => { + if (forbiddenType.value === 'validation') { + return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/40 dark:text-yellow-300' + } + return 'bg-red-100 text-red-700 dark:bg-red-900/40 dark:text-red-300' +}) + +const linkCopied = ref(false) +const copyValidationURL = async () => { + if (!validationURL.value) return + try { + await navigator.clipboard.writeText(validationURL.value) + linkCopied.value = true + setTimeout(() => { linkCopied.value = false }, 2000) + } catch { + // fallback: ignore + } +} + const loadUsage = async () => { if (!shouldFetchUsage.value) return @@ -848,18 +942,30 @@ const makeQuotaBar = ( let resetsAt: string | null = null if (startKey) { const extra = props.account.extra as Record | undefined - const startStr = extra?.[startKey] as string | undefined - if (startStr) { - const startDate = new Date(startStr) - const periodMs = startKey.includes('daily') ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000 - resetsAt = new Date(startDate.getTime() + periodMs).toISOString() + const isDaily = startKey.includes('daily') + const mode = isDaily + ? (extra?.quota_daily_reset_mode as string) || 'rolling' + : (extra?.quota_weekly_reset_mode as string) || 'rolling' + + if (mode === 'fixed') { + // Use pre-computed next reset time for fixed mode + const resetAtKey = isDaily ? 'quota_daily_reset_at' : 'quota_weekly_reset_at' + resetsAt = (extra?.[resetAtKey] as string) || null + } else { + // Rolling mode: compute from start + period + const startStr = extra?.[startKey] as string | undefined + if (startStr) { + const startDate = new Date(startStr) + const periodMs = isDaily ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000 + resetsAt = new Date(startDate.getTime() + periodMs).toISOString() + } } } return { utilization, resetsAt } } const hasApiKeyQuota = computed(() => { - if (props.account.type !== 'apikey') return false + if (props.account.type !== 'apikey' && props.account.type !== 'bedrock') return false return ( (props.account.quota_daily_limit ?? 0) > 0 || (props.account.quota_weekly_limit ?? 0) > 0 || diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 1ac96ed6..a492f6a3 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -323,35 +323,6 @@
-
@@ -956,7 +927,7 @@ -
+
+
- - + +
+ + +
-
- + + + + + +
+
-
- - -

{{ t('admin.accounts.bedrockSessionTokenHint') }}

-
+ +

{{ t('admin.accounts.bedrockRegionHint') }}

+ +
-
- -
-
- - -
-
- - -

{{ t('admin.accounts.bedrockRegionHint') }}

-
-
- -

{{ t('admin.accounts.bedrockForceGlobalHint') }}

-
- - +
- - - -
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
-
- - -
- -

- {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} - {{ t('admin.accounts.supportsAllModels') }} +

+

+ + {{ t('admin.accounts.poolModeInfo') }}

- - -
-
- - - - -
- - -
- -
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

- -
+ +

{{ t('admin.accounts.quotaLimit') }}

@@ -1634,9 +1568,21 @@ :totalLimit="editQuotaLimit" :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" + :dailyResetMode="editDailyResetMode" + :dailyResetHour="editDailyResetHour" + :weeklyResetMode="editWeeklyResetMode" + :weeklyResetDay="editWeeklyResetDay" + :weeklyResetHour="editWeeklyResetHour" + :resetTimezone="editResetTimezone" @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" + @update:dailyResetMode="editDailyResetMode = $event" + @update:dailyResetHour="editDailyResetHour = $event" + @update:weeklyResetMode="editWeeklyResetMode = $event" + @update:weeklyResetDay="editWeeklyResetDay = $event" + @update:weeklyResetHour="editWeeklyResetHour = $event" + @update:resetTimezone="editResetTimezone = $event" />

@@ -3014,13 +2960,19 @@ interface TempUnschedRuleForm { // State const step = ref(1) const submitting = ref(false) -const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'bedrock-apikey'>('oauth-based') // UI selection for account category +const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') const editQuotaLimit = ref(null) const editQuotaDailyLimit = ref(null) const editQuotaWeeklyLimit = ref(null) +const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editDailyResetHour = ref(null) +const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editWeeklyResetDay = ref(null) +const editWeeklyResetHour = ref(null) +const editResetTimezone = ref(null) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) @@ -3050,16 +3002,13 @@ const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('an const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock')) // Bedrock credentials +const bedrockAuthMode = ref<'sigv4' | 'apikey'>('sigv4') const bedrockAccessKeyId = ref('') const bedrockSecretAccessKey = ref('') const bedrockSessionToken = ref('') const bedrockRegion = ref('us-east-1') const bedrockForceGlobal = ref(false) - -// Bedrock API Key credentials const bedrockApiKeyValue = ref('') -const bedrockApiKeyRegion = ref('us-east-1') -const bedrockApiKeyForceGlobal = ref(false) const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping') @@ -3343,7 +3292,8 @@ watch( bedrockSessionToken.value = '' bedrockRegion.value = 'us-east-1' bedrockForceGlobal.value = false - bedrockApiKeyForceGlobal.value = false + bedrockAuthMode.value = 'sigv4' + bedrockApiKeyValue.value = '' // Reset Anthropic/Antigravity-specific settings when switching to other platforms if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') { interceptWarmupRequests.value = false @@ -3719,6 +3669,12 @@ const resetForm = () => { editQuotaLimit.value = null editQuotaDailyLimit.value = null editQuotaWeeklyLimit.value = null + editDailyResetMode.value = null + editDailyResetHour.value = null + editWeeklyResetMode.value = null + editWeeklyResetDay.value = null + editWeeklyResetHour.value = null + editResetTimezone.value = null modelMappings.value = [] modelRestrictionMode.value = 'whitelist' allowedModels.value = [...claudeModels] // Default fill related models @@ -3919,27 +3875,34 @@ const handleSubmit = async () => { appStore.showError(t('admin.accounts.pleaseEnterAccountName')) return } - if (!bedrockAccessKeyId.value.trim()) { - appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired')) - return - } - if (!bedrockSecretAccessKey.value.trim()) { - appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired')) - return - } - if (!bedrockRegion.value.trim()) { - appStore.showError(t('admin.accounts.bedrockRegionRequired')) - return - } const credentials: Record = { - aws_access_key_id: bedrockAccessKeyId.value.trim(), - aws_secret_access_key: bedrockSecretAccessKey.value.trim(), - aws_region: bedrockRegion.value.trim(), + auth_mode: bedrockAuthMode.value, + aws_region: bedrockRegion.value.trim() || 'us-east-1', } - if (bedrockSessionToken.value.trim()) { - credentials.aws_session_token = bedrockSessionToken.value.trim() + + if (bedrockAuthMode.value === 'sigv4') { + if (!bedrockAccessKeyId.value.trim()) { + appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired')) + return + } + if (!bedrockSecretAccessKey.value.trim()) { + appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired')) + return + } + credentials.aws_access_key_id = bedrockAccessKeyId.value.trim() + credentials.aws_secret_access_key = bedrockSecretAccessKey.value.trim() + if (bedrockSessionToken.value.trim()) { + credentials.aws_session_token = bedrockSessionToken.value.trim() + } + } else { + if (!bedrockApiKeyValue.value.trim()) { + appStore.showError(t('admin.accounts.bedrockApiKeyRequired')) + return + } + credentials.api_key = bedrockApiKeyValue.value.trim() } + if (bedrockForceGlobal.value) { credentials.aws_force_global = 'true' } @@ -3952,45 +3915,18 @@ const handleSubmit = async () => { credentials.model_mapping = modelMapping } + // Pool mode + if (poolModeEnabled.value) { + credentials.pool_mode = true + credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') await createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials) return } - // For Bedrock API Key type, create directly - if (form.platform === 'anthropic' && accountCategory.value === 'bedrock-apikey') { - if (!form.name.trim()) { - appStore.showError(t('admin.accounts.pleaseEnterAccountName')) - return - } - if (!bedrockApiKeyValue.value.trim()) { - appStore.showError(t('admin.accounts.bedrockApiKeyRequired')) - return - } - - const credentials: Record = { - api_key: bedrockApiKeyValue.value.trim(), - aws_region: bedrockApiKeyRegion.value.trim() || 'us-east-1', - } - if (bedrockApiKeyForceGlobal.value) { - credentials.aws_force_global = 'true' - } - - // Model mapping - const modelMapping = buildModelMappingObject( - modelRestrictionMode.value, allowedModels.value, modelMappings.value - ) - if (modelMapping) { - credentials.model_mapping = modelMapping - } - - applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') - - await createAccountAndFinish('anthropic', 'bedrock-apikey' as AccountType, credentials) - return - } - // For Antigravity upstream type, create directly if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { if (!form.name.trim()) { @@ -4233,9 +4169,9 @@ const createAccountAndFinish = async ( if (!applyTempUnschedConfig(credentials)) { return } - // Inject quota limits for apikey accounts + // Inject quota limits for apikey/bedrock accounts let finalExtra = extra - if (type === 'apikey') { + if (type === 'apikey' || type === 'bedrock') { const quotaExtra: Record = { ...(extra || {}) } if (editQuotaLimit.value != null && editQuotaLimit.value > 0) { quotaExtra.quota_limit = editQuotaLimit.value @@ -4246,6 +4182,19 @@ const createAccountAndFinish = async ( if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value } + // Quota reset mode config + if (editDailyResetMode.value === 'fixed') { + quotaExtra.quota_daily_reset_mode = 'fixed' + quotaExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0 + } + if (editWeeklyResetMode.value === 'fixed') { + quotaExtra.quota_weekly_reset_mode = 'fixed' + quotaExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1 + quotaExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0 + } + if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') { + quotaExtra.quota_reset_timezone = editResetTimezone.value || 'UTC' + } if (Object.keys(quotaExtra).length > 0) { finalExtra = quotaExtra } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index b18e9db6..77ead160 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -563,37 +563,54 @@
- +
-
- + + + + +
+ -
-
- - -

{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}

-
-
- - -

{{ t('admin.accounts.bedrockSessionTokenHint') }}

+

{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}

+ +

{{ t('admin.accounts.bedrockRegionHint') }}

+ +
-
- -
-
- - -

{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}

-
-
- - -

{{ t('admin.accounts.bedrockRegionHint') }}

-
-
- -

{{ t('admin.accounts.bedrockForceGlobalHint') }}

-
- - +
- - - -
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
-
- - -
- -

- {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} - {{ t('admin.accounts.supportsAllModels') }} +

+

+ + {{ t('admin.accounts.poolModeInfo') }}

- - -
-
- - - - -
- - -
- -
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

@@ -1182,8 +1149,8 @@
- -
+ +

{{ t('admin.accounts.quotaLimit') }}

@@ -1194,9 +1161,21 @@ :totalLimit="editQuotaLimit" :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" + :dailyResetMode="editDailyResetMode" + :dailyResetHour="editDailyResetHour" + :weeklyResetMode="editWeeklyResetMode" + :weeklyResetDay="editWeeklyResetDay" + :weeklyResetHour="editWeeklyResetHour" + :resetTimezone="editResetTimezone" @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" + @update:dailyResetMode="editDailyResetMode = $event" + @update:dailyResetHour="editDailyResetHour = $event" + @update:weeklyResetMode="editWeeklyResetMode = $event" + @update:weeklyResetDay="editWeeklyResetDay = $event" + @update:weeklyResetHour="editWeeklyResetHour = $event" + @update:resetTimezone="editResetTimezone = $event" />

@@ -1781,11 +1760,11 @@ const editBedrockSecretAccessKey = ref('') const editBedrockSessionToken = ref('') const editBedrockRegion = ref('') const editBedrockForceGlobal = ref(false) - -// Bedrock API Key credentials const editBedrockApiKeyValue = ref('') -const editBedrockApiKeyRegion = ref('') -const editBedrockApiKeyForceGlobal = ref(false) +const isBedrockAPIKeyMode = computed(() => + props.account?.type === 'bedrock' && + (props.account?.credentials as Record)?.auth_mode === 'apikey' +) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) @@ -1847,6 +1826,12 @@ const anthropicPassthroughEnabled = ref(false) const editQuotaLimit = ref(null) const editQuotaDailyLimit = ref(null) const editQuotaWeeklyLimit = ref(null) +const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editDailyResetHour = ref(null) +const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null) +const editWeeklyResetDay = ref(null) +const editWeeklyResetHour = ref(null) +const editResetTimezone = ref(null) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 @@ -2026,18 +2011,31 @@ watch( anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true } - // Load quota limit for apikey accounts - if (newAccount.type === 'apikey') { + // Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above) + if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') { const quotaVal = extra?.quota_limit as number | undefined editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null const dailyVal = extra?.quota_daily_limit as number | undefined editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null const weeklyVal = extra?.quota_weekly_limit as number | undefined editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null + // Load quota reset mode config + editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null + editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null + editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null + editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null + editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? null + editResetTimezone.value = (extra?.quota_reset_timezone as string) || null } else { editQuotaLimit.value = null editQuotaDailyLimit.value = null editQuotaWeeklyLimit.value = null + editDailyResetMode.value = null + editDailyResetHour.value = null + editWeeklyResetMode.value = null + editWeeklyResetDay.value = null + editWeeklyResetHour.value = null + editResetTimezone.value = null } // Load antigravity model mapping (Antigravity 只支持映射模式) @@ -2130,11 +2128,28 @@ watch( } } else if (newAccount.type === 'bedrock' && newAccount.credentials) { const bedrockCreds = newAccount.credentials as Record - editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' + const authMode = (bedrockCreds.auth_mode as string) || 'sigv4' editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' - editBedrockSecretAccessKey.value = '' - editBedrockSessionToken.value = '' + + if (authMode === 'apikey') { + editBedrockApiKeyValue.value = '' + } else { + editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' + editBedrockSecretAccessKey.value = '' + editBedrockSessionToken.value = '' + } + + // Load pool mode for bedrock + poolModeEnabled.value = bedrockCreds.pool_mode === true + const retryCount = bedrockCreds.pool_mode_retry_count + poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT + + // Load quota limits for bedrock + const bedrockExtra = (newAccount.extra as Record) || {} + editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null + editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null + editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? bedrockExtra.quota_weekly_limit : null // Load model mappings for bedrock const existingMappings = bedrockCreds.model_mapping as Record | undefined @@ -2155,31 +2170,6 @@ watch( modelMappings.value = [] allowedModels.value = [] } - } else if (newAccount.type === 'bedrock-apikey' && newAccount.credentials) { - const bedrockApiKeyCreds = newAccount.credentials as Record - editBedrockApiKeyRegion.value = (bedrockApiKeyCreds.aws_region as string) || 'us-east-1' - editBedrockApiKeyForceGlobal.value = (bedrockApiKeyCreds.aws_force_global as string) === 'true' - editBedrockApiKeyValue.value = '' - - // Load model mappings for bedrock-apikey - const existingMappings = bedrockApiKeyCreds.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } } else if (newAccount.type === 'upstream' && newAccount.credentials) { const credentials = newAccount.credentials as Record editBaseUrl.value = (credentials.base_url as string) || '' @@ -2727,7 +2717,6 @@ const handleSubmit = async () => { const currentCredentials = (props.account.credentials as Record) || {} const newCredentials: Record = { ...currentCredentials } - newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim() newCredentials.aws_region = editBedrockRegion.value.trim() if (editBedrockForceGlobal.value) { newCredentials.aws_force_global = 'true' @@ -2735,42 +2724,29 @@ const handleSubmit = async () => { delete newCredentials.aws_force_global } - // Only update secrets if user provided new values - if (editBedrockSecretAccessKey.value.trim()) { - newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim() - } - if (editBedrockSessionToken.value.trim()) { - newCredentials.aws_session_token = editBedrockSessionToken.value.trim() - } - - // Model mapping - const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) - if (modelMapping) { - newCredentials.model_mapping = modelMapping + if (isBedrockAPIKeyMode.value) { + // API Key mode: only update api_key if user provided new value + if (editBedrockApiKeyValue.value.trim()) { + newCredentials.api_key = editBedrockApiKeyValue.value.trim() + } } else { - delete newCredentials.model_mapping + // SigV4 mode + newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim() + if (editBedrockSecretAccessKey.value.trim()) { + newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim() + } + if (editBedrockSessionToken.value.trim()) { + newCredentials.aws_session_token = editBedrockSessionToken.value.trim() + } } - applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') - if (!applyTempUnschedConfig(newCredentials)) { - return - } - - updatePayload.credentials = newCredentials - } else if (props.account.type === 'bedrock-apikey') { - const currentCredentials = (props.account.credentials as Record) || {} - const newCredentials: Record = { ...currentCredentials } - - newCredentials.aws_region = editBedrockApiKeyRegion.value.trim() || 'us-east-1' - if (editBedrockApiKeyForceGlobal.value) { - newCredentials.aws_force_global = 'true' + // Pool mode + if (poolModeEnabled.value) { + newCredentials.pool_mode = true + newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) } else { - delete newCredentials.aws_force_global - } - - // Only update API key if user provided new value - if (editBedrockApiKeyValue.value.trim()) { - newCredentials.api_key = editBedrockApiKeyValue.value.trim() + delete newCredentials.pool_mode + delete newCredentials.pool_mode_retry_count } // Model mapping @@ -2980,8 +2956,8 @@ const handleSubmit = async () => { updatePayload.extra = newExtra } - // For apikey accounts, handle quota_limit in extra - if (props.account.type === 'apikey') { + // For apikey/bedrock accounts, handle quota_limit in extra + if (props.account.type === 'apikey' || props.account.type === 'bedrock') { const currentExtra = (updatePayload.extra as Record) || (props.account.extra as Record) || {} const newExtra: Record = { ...currentExtra } @@ -3000,6 +2976,28 @@ const handleSubmit = async () => { } else { delete newExtra.quota_weekly_limit } + // Quota reset mode config + if (editDailyResetMode.value === 'fixed') { + newExtra.quota_daily_reset_mode = 'fixed' + newExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0 + } else { + delete newExtra.quota_daily_reset_mode + delete newExtra.quota_daily_reset_hour + } + if (editWeeklyResetMode.value === 'fixed') { + newExtra.quota_weekly_reset_mode = 'fixed' + newExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1 + newExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0 + } else { + delete newExtra.quota_weekly_reset_mode + delete newExtra.quota_weekly_reset_day + delete newExtra.quota_weekly_reset_hour + } + if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') { + newExtra.quota_reset_timezone = editResetTimezone.value || 'UTC' + } else { + delete newExtra.quota_reset_timezone + } updatePayload.extra = newExtra } diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 505118ba..fdc19ad9 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -8,12 +8,24 @@ const props = defineProps<{ totalLimit: number | null dailyLimit: number | null weeklyLimit: number | null + dailyResetMode: 'rolling' | 'fixed' | null + dailyResetHour: number | null + weeklyResetMode: 'rolling' | 'fixed' | null + weeklyResetDay: number | null + weeklyResetHour: number | null + resetTimezone: string | null }>() const emit = defineEmits<{ 'update:totalLimit': [value: number | null] 'update:dailyLimit': [value: number | null] 'update:weeklyLimit': [value: number | null] + 'update:dailyResetMode': [value: 'rolling' | 'fixed' | null] + 'update:dailyResetHour': [value: number | null] + 'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null] + 'update:weeklyResetDay': [value: number | null] + 'update:weeklyResetHour': [value: number | null] + 'update:resetTimezone': [value: string | null] }>() const enabled = computed(() => @@ -35,9 +47,56 @@ watch(localEnabled, (val) => { emit('update:totalLimit', null) emit('update:dailyLimit', null) emit('update:weeklyLimit', null) + emit('update:dailyResetMode', null) + emit('update:dailyResetHour', null) + emit('update:weeklyResetMode', null) + emit('update:weeklyResetDay', null) + emit('update:weeklyResetHour', null) + emit('update:resetTimezone', null) } }) +// Whether any fixed mode is active (to show timezone selector) +const hasFixedMode = computed(() => + props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed' +) + +// Common timezone options +const timezoneOptions = [ + 'UTC', + 'Asia/Shanghai', + 'Asia/Tokyo', + 'Asia/Seoul', + 'Asia/Singapore', + 'Asia/Kolkata', + 'Asia/Dubai', + 'Europe/London', + 'Europe/Paris', + 'Europe/Berlin', + 'Europe/Moscow', + 'America/New_York', + 'America/Chicago', + 'America/Denver', + 'America/Los_Angeles', + 'America/Sao_Paulo', + 'Australia/Sydney', + 'Pacific/Auckland', +] + +// Hours for dropdown (0-23) +const hourOptions = Array.from({ length: 24 }, (_, i) => i) + +// Day of week options +const dayOptions = [ + { value: 1, key: 'monday' }, + { value: 2, key: 'tuesday' }, + { value: 3, key: 'wednesday' }, + { value: 4, key: 'thursday' }, + { value: 5, key: 'friday' }, + { value: 6, key: 'saturday' }, + { value: 0, key: 'sunday' }, +] + const onTotalInput = (e: Event) => { const raw = (e.target as HTMLInputElement).valueAsNumber emit('update:totalLimit', Number.isNaN(raw) ? null : raw) @@ -50,6 +109,25 @@ const onWeeklyInput = (e: Event) => { const raw = (e.target as HTMLInputElement).valueAsNumber emit('update:weeklyLimit', Number.isNaN(raw) ? null : raw) } + +const onDailyModeChange = (e: Event) => { + const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed' + emit('update:dailyResetMode', val) + if (val === 'fixed') { + if (props.dailyResetHour == null) emit('update:dailyResetHour', 0) + if (!props.resetTimezone) emit('update:resetTimezone', 'UTC') + } +} + +const onWeeklyModeChange = (e: Event) => { + const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed' + emit('update:weeklyResetMode', val) + if (val === 'fixed') { + if (props.weeklyResetDay == null) emit('update:weeklyResetDay', 1) + if (props.weeklyResetHour == null) emit('update:weeklyResetHour', 0) + if (!props.resetTimezone) emit('update:resetTimezone', 'UTC') + } +} -