mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-08 01:00:21 +08:00
Merge branch 'main' into fix/enc_coot
This commit is contained in:
@@ -27,12 +27,11 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
|
||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -1718,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
@@ -125,6 +125,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -199,6 +200,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -473,6 +477,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -571,6 +576,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -725,6 +731,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
@@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -81,6 +81,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -111,6 +114,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
|
||||
@@ -203,6 +203,16 @@ type Account struct {
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
|
||||
@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -655,14 +655,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,6 +19,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// ForbiddenError 表示上游返回 403 Forbidden
|
||||
type ForbiddenError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *ForbiddenError) Error() string {
|
||||
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -514,7 +524,20 @@ type ModelQuotaInfo struct {
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||
}
|
||||
|
||||
// DeprecatedModelInfo 废弃模型转发信息
|
||||
type DeprecatedModelInfo struct {
|
||||
NewModelID string `json:"newModelId"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
@@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct {
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, nil, &ForbiddenError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBodyBytes),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Empty(t, items[2].ID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
|
||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Empty(t, items[1].ID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Empty(t, items[2].ID)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Equal(t, "AB", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
// Reasoning summary is prepended to text
|
||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
||||
assert.Equal(t, "The answer is 42.", content)
|
||||
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.done",
|
||||
}, state)
|
||||
require.Len(t, chunks, 0)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "plan",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "answer",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package apicompat
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||
s, err := parseAssistantContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
ID: tc.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// parseAssistantContent returns assistant content as plain text.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JSON string
|
||||
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||
//
|
||||
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var parts []map[string]any
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Keep compatibility with prior behavior: unsupported assistant content
|
||||
// formats are ignored instead of failing the whole request conversion.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
write := func(v string) error {
|
||||
_, err := b.WriteString(v)
|
||||
return err
|
||||
}
|
||||
for _, p := range parts {
|
||||
typ, _ := p["type"].(string)
|
||||
text, _ := p["text"].(string)
|
||||
thinking, _ := p["thinking"].(string)
|
||||
|
||||
switch typ {
|
||||
case "thinking", "reasoning":
|
||||
if thinking != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(thinking); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if text != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if text != "" {
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
contentText += s.Text
|
||||
reasoningText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
if reasoningText != "" {
|
||||
msg.ReasoningContent = reasoningText
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
reasoning := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
|
||||
@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||
const dailyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||
const weeklyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextDailyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||
CASE WHEN NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '1 day'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is before today's reset point → next reset is today
|
||||
ELSE (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextWeeklyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute this week's reset point in the configured timezone
|
||||
-- Step 1: get today's date at reset hour in configured tz
|
||||
-- Step 2: compute days forward to target weekday
|
||||
-- Step 3: if same day but past reset hour, advance 7 days
|
||||
CASE
|
||||
WHEN (
|
||||
-- days_forward = (target_day - current_day + 7) % 7
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) = 0 AND NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- Same weekday and past reset hour → next week
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '7 days'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
ELSE (
|
||||
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ ((
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) || ' days')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "sync-credentials-update",
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.2",
|
||||
},
|
||||
}
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
|
||||
@@ -537,6 +537,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"purchase_subscription_url": "",
|
||||
"min_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"custom_menu_items": []
|
||||
}
|
||||
}`,
|
||||
|
||||
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
|
||||
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
|
||||
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
if role == "admin" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
// Allow login, 2FA, logout, refresh, public settings
|
||||
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||
for _, suffix := range allowedSuffixes {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
v, ok := r.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = make(map[string]string, len(settings))
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
|
||||
t.Helper()
|
||||
|
||||
repo := &bmSettingRepo{
|
||||
values: map[string]string{
|
||||
service.SettingKeyBackendModeEnabled: enabled,
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{})
|
||||
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
|
||||
BackendModeEnabled: enabled == "true",
|
||||
}))
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func stringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestBackendModeUserGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
role *string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_admin_allowed",
|
||||
enabled: "true",
|
||||
role: stringPtr("admin"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_user_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_no_role_blocked",
|
||||
enabled: "true",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_empty_role_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr(""),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
if tc.role != nil {
|
||||
role := *tc.role
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(ContextKeyUserRole), role)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeUserGuard(svc))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendModeAuthGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login_2fa",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login/2fa",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_logout",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/logout",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_refresh",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/refresh",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_register",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_forgot_password",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/forgot-password",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeAuthGuard(svc))
|
||||
r.Any("/*path", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -107,9 +107,9 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
|
||||
{
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
|
||||
// 需要认证的当前用户信息
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
// 撤销所有会话(需要认证)
|
||||
|
||||
@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
nil,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
@@ -522,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||
// 如果未配置 mapping,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
|
||||
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
|
||||
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
return requestedModel, false
|
||||
}
|
||||
// 精确匹配优先
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
return mappedModel, true
|
||||
}
|
||||
// 通配符匹配(最长优先)
|
||||
return matchWildcardMapping(mapping, requestedModel)
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
@@ -605,9 +613,7 @@ func matchWildcard(pattern, str string) bool {
|
||||
return matchAntigravityWildcard(pattern, str)
|
||||
}
|
||||
|
||||
// matchWildcardMapping 通配符映射匹配(最长优先)
|
||||
// 如果没有匹配,返回原始字符串
|
||||
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
|
||||
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
|
||||
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
||||
type patternMatch struct {
|
||||
pattern string
|
||||
@@ -622,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return requestedModel // 无匹配,返回原始模型名
|
||||
return requestedModel, false // 无匹配,返回原始模型名
|
||||
}
|
||||
|
||||
// 按 pattern 长度降序排序
|
||||
@@ -633,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
return matches[i].pattern < matches[j].pattern
|
||||
})
|
||||
|
||||
return matches[0].target
|
||||
return matches[0].target, true
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
@@ -651,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
@@ -766,11 +772,16 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrock() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrockAPIKey() bool {
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
|
||||
return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
|
||||
}
|
||||
|
||||
// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性
|
||||
func (a *Account) IsAPIKeyOrBedrock() bool {
|
||||
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
@@ -1269,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// getExtraString 从 Extra 中读取指定 key 的字符串值
|
||||
func (a *Account) getExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getExtraInt 从 Extra 中读取指定 key 的 int 值
|
||||
func (a *Account) getExtraInt(key string) int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return int(parseExtraFloat64(v))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaDailyResetMode() string {
|
||||
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaDailyResetHour() int {
|
||||
return a.getExtraInt("quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaWeeklyResetMode() string {
|
||||
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一)
|
||||
func (a *Account) GetQuotaWeeklyResetDay() int {
|
||||
if a.Extra == nil {
|
||||
return 1
|
||||
}
|
||||
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
|
||||
return 1
|
||||
}
|
||||
return a.getExtraInt("quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaWeeklyResetHour() int {
|
||||
return a.getExtraInt("quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC"
|
||||
func (a *Account) GetQuotaResetTimezone() string {
|
||||
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
|
||||
return tz
|
||||
}
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
|
||||
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if !after.Before(today) {
|
||||
return today.AddDate(0, 0, 1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点
|
||||
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if now.Before(today) {
|
||||
return today.AddDate(0, 0, -1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点
|
||||
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
|
||||
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysForward := (day - currentDay + 7) % 7
|
||||
if daysForward == 0 && !after.Before(todayReset) {
|
||||
daysForward = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, daysForward)
|
||||
}
|
||||
|
||||
// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点
|
||||
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysBack := (currentDay - day + 7) % 7
|
||||
if daysBack == 0 && now.Before(todayReset) {
|
||||
daysBack = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, -daysBack)
|
||||
}
|
||||
|
||||
// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at
|
||||
// 在保存账号配置时调用
|
||||
func ComputeQuotaResetAt(extra map[string]any) {
|
||||
now := time.Now()
|
||||
tzName, _ := extra["quota_reset_timezone"].(string)
|
||||
if tzName == "" {
|
||||
tzName = "UTC"
|
||||
}
|
||||
tz, err := time.LoadLocation(tzName)
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
|
||||
// 日配额固定重置时间
|
||||
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
|
||||
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedDailyReset(hour, tz, now)
|
||||
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_daily_reset_at")
|
||||
}
|
||||
|
||||
// 周配额固定重置时间
|
||||
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
|
||||
day := 1 // 默认周一
|
||||
if d, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day = int(parseExtraFloat64(d))
|
||||
}
|
||||
if day < 0 || day > 6 {
|
||||
day = 1
|
||||
}
|
||||
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
|
||||
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_weekly_reset_at")
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性
|
||||
func ValidateQuotaResetConfig(extra map[string]any) error {
|
||||
if extra == nil {
|
||||
return nil
|
||||
}
|
||||
// 校验时区
|
||||
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
|
||||
}
|
||||
}
|
||||
// 日配额重置模式
|
||||
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 日配额重置小时
|
||||
if v, ok := extra["quota_daily_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_daily_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
// 周配额重置模式
|
||||
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 周配额重置星期几
|
||||
if v, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day := int(parseExtraFloat64(v))
|
||||
if day < 0 || day > 6 {
|
||||
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
|
||||
}
|
||||
}
|
||||
// 周配额重置小时
|
||||
if v, ok := extra["quota_weekly_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
@@ -1291,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool {
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaDailyResetMode() == "fixed" {
|
||||
expired = a.isFixedDailyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
expired = a.isFixedWeeklyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 7*24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
516
backend/internal/service/account_quota_reset_test.go
Normal file
516
backend/internal/service/account_quota_reset_test.go
Normal file
@@ -0,0 +1,516 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 06:00 UTC, reset hour = 9
|
||||
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Exactly at reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// After reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Reset at hour 0 (midnight), currently 23:59
|
||||
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
|
||||
got := nextFixedDailyReset(0, tz, after)
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
|
||||
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// Before today's 9:00 → yesterday 9:00
|
||||
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// At exactly 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// After 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
|
||||
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
|
||||
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
|
||||
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-23
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(0, 0, tz, after)
|
||||
// Next Sunday = 2026-03-15
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
|
||||
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
|
||||
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday at 9:00 = 2026-03-09
|
||||
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedDailyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started after the most recent reset → not expired
|
||||
// (This test uses a time very close to "now", which is after the last reset)
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 3 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Invalid/Timezone",
|
||||
}}
|
||||
// Invalid timezone falls back to UTC
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedWeeklyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 1 minute ago → not expired
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 10 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-240 * time.Hour)
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ValidateQuotaResetConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(nil))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_reset_timezone": "Not/A/Timezone",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_reset_timezone")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "invalid",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(24),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "unknown",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(7),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_hour": float64(25),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
|
||||
// All boundary values should be valid
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(23),
|
||||
"quota_weekly_reset_day": float64(0), // Sunday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
|
||||
extra2 := map[string]any{
|
||||
"quota_daily_reset_hour": float64(0),
|
||||
"quota_weekly_reset_day": float64(6), // Saturday
|
||||
"quota_weekly_reset_hour": float64(23),
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra2))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ComputeQuotaResetAt
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
|
||||
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok, "quota_daily_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset hour should be 9 UTC
|
||||
assert.Equal(t, 9, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1), // Monday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
|
||||
require.True(t, ok, "quota_weekly_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset day should be Monday
|
||||
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// In Shanghai timezone, the hour should be 9
|
||||
assert.Equal(t, 9, resetAt.In(tz).Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(12),
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Default timezone is UTC
|
||||
assert.Equal(t, 12, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(99),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Invalid hour → clamped to 0
|
||||
assert.Equal(t, 0, resetAt.UTC().Hour())
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -100,6 +101,7 @@ type antigravityUsageCache struct {
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
@@ -108,11 +110,12 @@ const (
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
@@ -149,6 +152,18 @@ type AntigravityModelQuota struct {
|
||||
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||
}
|
||||
|
||||
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
|
||||
type AntigravityModelDetail struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
SupportsImages *bool `json:"supports_images,omitempty"`
|
||||
SupportsThinking *bool `json:"supports_thinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinking_budget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
@@ -164,6 +179,33 @@ type UsageInfo struct {
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
|
||||
// Antigravity 账号级信息
|
||||
SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
|
||||
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称
|
||||
|
||||
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
||||
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
||||
|
||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||
|
||||
// Antigravity 账号是否被上游禁止 (HTTP 403)
|
||||
IsForbidden bool `json:"is_forbidden,omitempty"`
|
||||
ForbiddenReason string `json:"forbidden_reason,omitempty"`
|
||||
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
|
||||
ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接
|
||||
|
||||
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
|
||||
NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation)
|
||||
IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation)
|
||||
NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401)
|
||||
|
||||
// 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
|
||||
// 获取 usage 时的错误信息(降级返回,而非 500)
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
@@ -648,34 +690,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
|
||||
// 1. 检查缓存(10 分钟)
|
||||
// 1. 检查缓存
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
// 重新计算 RemainingSeconds
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取代理 URL
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
||||
// 2. singleflight 防止并发击穿
|
||||
flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
|
||||
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
|
||||
// 再次检查缓存(等待期间可能已被填充)
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
// 重新计算 RemainingSeconds,避免返回过时的剩余秒数
|
||||
recalcAntigravityRemainingSeconds(usage)
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 调用 API 获取额度
|
||||
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
||||
}
|
||||
// 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败
|
||||
fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer fetchCancel()
|
||||
|
||||
// 4. 缓存结果
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: result.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
|
||||
fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
|
||||
if err != nil {
|
||||
degraded := buildAntigravityDegradedUsage(err)
|
||||
enrichUsageWithAccountError(degraded, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: degraded,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return degraded, nil
|
||||
}
|
||||
|
||||
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: fetchResult.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return fetchResult.UsageInfo, nil
|
||||
})
|
||||
|
||||
return result.UsageInfo, nil
|
||||
if flightErr != nil {
|
||||
return nil, flightErr
|
||||
}
|
||||
usage, ok := result.(*UsageInfo)
|
||||
if !ok || usage == nil {
|
||||
now := time.Now()
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
||||
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
||||
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
|
||||
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
info.FiveHour.RemainingSeconds = remaining
|
||||
}
|
||||
}
|
||||
|
||||
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
|
||||
// 403 forbidden 状态稳定,缓存与成功相同(3 分钟);
|
||||
// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。
|
||||
func antigravityCacheTTL(info *UsageInfo) time.Duration {
|
||||
if info == nil {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
if info.IsForbidden {
|
||||
return apiCacheTTL // 封号/验证状态不会很快变
|
||||
}
|
||||
if info.ErrorCode != "" || info.Error != "" {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
return apiCacheTTL
|
||||
}
|
||||
|
||||
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
|
||||
func buildAntigravityDegradedUsage(err error) *UsageInfo {
|
||||
now := time.Now()
|
||||
errMsg := fmt.Sprintf("usage API error: %v", err)
|
||||
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
|
||||
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
// 从错误信息推断 error_code 和状态标记
|
||||
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
|
||||
errStr := err.Error()
|
||||
switch {
|
||||
case strings.Contains(errStr, "HTTP 401") ||
|
||||
strings.Contains(errStr, "UNAUTHENTICATED") ||
|
||||
strings.Contains(errStr, "invalid_grant"):
|
||||
info.ErrorCode = errorCodeUnauthenticated
|
||||
info.NeedsReauth = true
|
||||
case strings.Contains(errStr, "HTTP 429"):
|
||||
info.ErrorCode = errorCodeRateLimited
|
||||
default:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
|
||||
// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error,
|
||||
//
|
||||
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
|
||||
//
|
||||
// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401,
|
||||
//
|
||||
// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。
|
||||
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
|
||||
if info == nil || account == nil || account.Status != StatusError {
|
||||
return
|
||||
}
|
||||
msg := strings.ToLower(account.ErrorMessage)
|
||||
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
|
||||
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
|
||||
return
|
||||
}
|
||||
fbType := classifyForbiddenType(account.ErrorMessage)
|
||||
info.IsForbidden = true
|
||||
info.ForbiddenType = fbType
|
||||
info.ForbiddenReason = account.ErrorMessage
|
||||
info.NeedsVerify = fbType == forbiddenTypeValidation
|
||||
info.IsBanned = fbType == forbiddenTypeViolation
|
||||
info.ValidationURL = extractValidationURL(account.ErrorMessage)
|
||||
info.ErrorCode = errorCodeForbidden
|
||||
info.NeedsReauth = false
|
||||
}
|
||||
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
|
||||
@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchWildcardMapping(t *testing.T) {
|
||||
func TestMatchWildcardMappingResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mapping map[string]string
|
||||
requestedModel string
|
||||
expected string
|
||||
matched bool
|
||||
}{
|
||||
// 精确匹配优先于通配符
|
||||
{
|
||||
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5-exact",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 最长通配符优先
|
||||
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-series",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 单个通配符
|
||||
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-opus-4-5",
|
||||
expected: "claude-mapped",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 无匹配返回原始模型
|
||||
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash",
|
||||
expected: "gemini-3-flash",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// 空映射返回原始模型
|
||||
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
mapping: map[string]string{},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// Gemini 模型映射
|
||||
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash-preview",
|
||||
expected: "gemini-3-pro-high",
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchWildcardMapping(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected)
|
||||
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected || matched != tt.matched {
|
||||
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolveMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expectedModel string
|
||||
expectedMatch bool
|
||||
}{
|
||||
{
|
||||
name: "no mapping reports unmatched",
|
||||
credentials: nil,
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
{
|
||||
name: "exact passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "missing mapping reports unmatched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
|
||||
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
|
||||
@@ -1462,6 +1462,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
// 预计算固定时间重置的下次重置时间
|
||||
if account.Extra != nil {
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
@@ -1535,6 +1542,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
// 校验并预计算固定时间重置的下次重置时间
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
||||
|
||||
@@ -2,12 +2,29 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
const (
|
||||
forbiddenTypeValidation = "validation"
|
||||
forbiddenTypeViolation = "violation"
|
||||
forbiddenTypeForbidden = "forbidden"
|
||||
|
||||
// 机器可读的错误码
|
||||
errorCodeForbidden = "forbidden"
|
||||
errorCodeUnauthenticated = "unauthenticated"
|
||||
errorCodeRateLimited = "rate_limited"
|
||||
errorCodeNetworkError = "network_error"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
// 403 Forbidden: 不报错,返回 is_forbidden 标记
|
||||
var forbiddenErr *antigravity.ForbiddenError
|
||||
if errors.As(err, &forbiddenErr) {
|
||||
now := time.Now()
|
||||
fbType := classifyForbiddenType(forbiddenErr.Body)
|
||||
return &QuotaResult{
|
||||
UsageInfo: &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
IsForbidden: true,
|
||||
ForbiddenReason: forbiddenErr.Body,
|
||||
ForbiddenType: fbType,
|
||||
ValidationURL: extractValidationURL(forbiddenErr.Body),
|
||||
NeedsVerify: fbType == forbiddenTypeValidation,
|
||||
IsBanned: fbType == forbiddenTypeViolation,
|
||||
ErrorCode: errorCodeForbidden,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
|
||||
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp)
|
||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
@@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
|
||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||
return "", ""
|
||||
}
|
||||
if loadResp == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota
|
||||
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||
normalized = normalizeTier(raw)
|
||||
return raw, normalized
|
||||
}
|
||||
|
||||
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||
func normalizeTier(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
switch {
|
||||
case strings.Contains(lower, "ultra"):
|
||||
return "ULTRA"
|
||||
case strings.Contains(lower, "pro"):
|
||||
return "PRO"
|
||||
case strings.Contains(lower, "free"):
|
||||
return "FREE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
|
||||
SubscriptionTier: tierNormalized,
|
||||
SubscriptionTierRaw: tierRaw,
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
@@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
|
||||
// 填充模型详细能力信息
|
||||
detail := &AntigravityModelDetail{
|
||||
DisplayName: modelInfo.DisplayName,
|
||||
SupportsImages: modelInfo.SupportsImages,
|
||||
SupportsThinking: modelInfo.SupportsThinking,
|
||||
ThinkingBudget: modelInfo.ThinkingBudget,
|
||||
Recommended: modelInfo.Recommended,
|
||||
MaxTokens: modelInfo.MaxTokens,
|
||||
MaxOutputTokens: modelInfo.MaxOutputTokens,
|
||||
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
|
||||
}
|
||||
info.AntigravityQuotaDetails[modelName] = detail
|
||||
}
|
||||
|
||||
// 废弃模型转发规则
|
||||
if len(modelsResp.DeprecatedModelIDs) > 0 {
|
||||
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
|
||||
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
|
||||
info.ModelForwardingRules[oldID] = deprecated.NewModelID
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
@@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
|
||||
}
|
||||
return proxy.URL()
|
||||
}
|
||||
|
||||
// classifyForbiddenType 根据 403 响应体判断禁止类型
|
||||
func classifyForbiddenType(body string) string {
|
||||
lower := strings.ToLower(body)
|
||||
switch {
|
||||
case strings.Contains(lower, "validation_required") ||
|
||||
strings.Contains(lower, "verify your account") ||
|
||||
strings.Contains(lower, "validation_url"):
|
||||
return forbiddenTypeValidation
|
||||
case strings.Contains(lower, "terms of service") ||
|
||||
strings.Contains(lower, "violation"):
|
||||
return forbiddenTypeViolation
|
||||
default:
|
||||
return forbiddenTypeForbidden
|
||||
}
|
||||
}
|
||||
|
||||
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
|
||||
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
|
||||
|
||||
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
|
||||
func extractValidationURL(body string) string {
|
||||
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
|
||||
var parsed struct {
|
||||
Error struct {
|
||||
Details []struct {
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
} `json:"details"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal([]byte(body), &parsed) == nil {
|
||||
for _, detail := range parsed.Error.Details {
|
||||
if u := detail.Metadata["validation_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
if u := detail.Metadata["appeal_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 降级:正则匹配 URL
|
||||
lower := strings.ToLower(body)
|
||||
if !strings.Contains(lower, "validation") &&
|
||||
!strings.Contains(lower, "verify") &&
|
||||
!strings.Contains(lower, "appeal") {
|
||||
return ""
|
||||
}
|
||||
// 先解码常见转义再匹配
|
||||
normalized := strings.ReplaceAll(body, `\u0026`, "&")
|
||||
if m := urlPattern.FindString(normalized); m != "" {
|
||||
return m
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
@@ -0,0 +1,497 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// normalizeTier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNormalizeTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expected string
|
||||
}{
|
||||
{name: "empty string", raw: "", expected: ""},
|
||||
{name: "free-tier", raw: "free-tier", expected: "FREE"},
|
||||
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
|
||||
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
|
||||
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
|
||||
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
|
||||
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
|
||||
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
|
||||
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeTier(tt.raw)
|
||||
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildUsageInfo
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func aqfBoolPtr(v bool) *bool { return &v }
|
||||
func aqfIntPtr(v int) *int { return &v }
|
||||
|
||||
func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.75,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
DisplayName: "Claude Sonnet 4",
|
||||
SupportsImages: aqfBoolPtr(true),
|
||||
SupportsThinking: aqfBoolPtr(false),
|
||||
ThinkingBudget: aqfIntPtr(0),
|
||||
Recommended: aqfBoolPtr(true),
|
||||
MaxTokens: aqfIntPtr(200000),
|
||||
MaxOutputTokens: aqfIntPtr(16384),
|
||||
SupportedMimeTypes: map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "2026-03-08T15:00:00Z",
|
||||
},
|
||||
DisplayName: "Gemini 2.5 Pro",
|
||||
MaxTokens: aqfIntPtr(1000000),
|
||||
MaxOutputTokens: aqfIntPtr(65536),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
|
||||
|
||||
// 基本字段
|
||||
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||
require.Equal(t, "PRO", info.SubscriptionTier)
|
||||
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
|
||||
|
||||
// AntigravityQuota
|
||||
require.Len(t, info.AntigravityQuota, 2)
|
||||
|
||||
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetQuota)
|
||||
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
|
||||
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
|
||||
|
||||
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiQuota)
|
||||
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
|
||||
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
|
||||
|
||||
// AntigravityQuotaDetails
|
||||
require.Len(t, info.AntigravityQuotaDetails, 2)
|
||||
|
||||
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetDetail)
|
||||
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
|
||||
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
|
||||
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
|
||||
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
|
||||
|
||||
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiDetail)
|
||||
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
|
||||
require.Nil(t, geminiDetail.SupportsImages)
|
||||
require.Nil(t, geminiDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
|
||||
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
|
||||
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Len(t, info.ModelForwardingRules, 2)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.NotNil(t, info.AntigravityQuota)
|
||||
require.Empty(t, info.AntigravityQuota)
|
||||
require.NotNil(t, info.AntigravityQuotaDetails)
|
||||
require.Empty(t, info.AntigravityQuotaDetails)
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"model-without-quota": {
|
||||
DisplayName: "No Quota Model",
|
||||
// QuotaInfo is nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
|
||||
// When the first priority model exists, it should be used for FiveHour
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.40,
|
||||
ResetTime: "2026-03-08T18:00:00Z",
|
||||
},
|
||||
},
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.80,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||
expectedUtilization := (1.0 - 0.80) * 100 // 20
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.60,
|
||||
ResetTime: "2026-03-08T14:00:00Z",
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only gemini-2.5-pro exists (third in priority list)
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
"other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.90,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// None of the priority models exist
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "", // empty reset time
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.0, // fully used
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 100, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0, // fully available
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 0, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||
forbiddenErr := &antigravity.ForbiddenError{
|
||||
StatusCode: 403,
|
||||
Body: "Access denied",
|
||||
}
|
||||
|
||||
// 验证 ForbiddenError 满足 errors.As
|
||||
var target *antigravity.ForbiddenError
|
||||
require.True(t, errors.As(forbiddenErr, &target))
|
||||
require.Equal(t, 403, target.StatusCode)
|
||||
require.Equal(t, "Access denied", target.Body)
|
||||
require.Contains(t, forbiddenErr.Error(), "403")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// classifyForbiddenType
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClassifyForbiddenType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "VALIDATION_REQUIRED keyword",
|
||||
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "verify your account",
|
||||
body: `Please verify your account to continue`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "contains validation_url field",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "terms of service violation",
|
||||
body: `Your account has been suspended for Terms of Service violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "violation keyword",
|
||||
body: `Account suspended due to policy violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "generic 403",
|
||||
body: `Access denied`,
|
||||
expected: "forbidden",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "forbidden",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyForbiddenType(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extractValidationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractValidationURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "structured validation_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
|
||||
expected: "https://accounts.google.com/verify?token=abc",
|
||||
},
|
||||
{
|
||||
name: "structured appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
|
||||
expected: "https://support.google.com/appeal/123",
|
||||
},
|
||||
{
|
||||
name: "validation_url takes priority over appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
|
||||
expected: "https://v.com",
|
||||
},
|
||||
{
|
||||
name: "fallback regex with verify keyword",
|
||||
body: `Please verify your account at https://accounts.google.com/verify`,
|
||||
expected: "https://accounts.google.com/verify",
|
||||
},
|
||||
{
|
||||
name: "no URL in generic forbidden",
|
||||
body: `Access denied`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "URL present but no validation keywords",
|
||||
body: `Error at https://example.com/something`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "unicode escaped ampersand",
|
||||
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
|
||||
expected: "https://accounts.google.com/verify?a=1&b=2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractValidationURL(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1087,6 +1087,12 @@ type TokenPair struct {
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
}
|
||||
|
||||
// TokenPairWithUser extends TokenPair with user role for backend mode checks
|
||||
type TokenPairWithUser struct {
|
||||
TokenPair
|
||||
UserRole string
|
||||
}
|
||||
|
||||
// GenerateTokenPair 生成Access Token和Refresh Token对
|
||||
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
||||
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
||||
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
||||
|
||||
// RefreshTokenPair 使用Refresh Token刷新Token对
|
||||
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
}
|
||||
|
||||
// 生成新的Token对,保持同一个家族ID
|
||||
return s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenPairWithUser{
|
||||
TokenPair: *pair,
|
||||
UserRole: user.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeRefreshToken 撤销单个Refresh Token
|
||||
|
||||
@@ -29,12 +29,11 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -221,6 +220,9 @@ const (
|
||||
|
||||
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403)
|
||||
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
|
||||
|
||||
// SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录
|
||||
SettingKeyBackendModeEnabled = "backend_mode_enabled"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
|
||||
// second hit 仍然返回 TempUnscheduled。
|
||||
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
|
||||
@@ -2173,10 +2173,10 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
|
||||
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
|
||||
// 仅适用于配置了 quota_limit 的 apikey 类型账号
|
||||
// isAccountSchedulableForQuota 检查账号是否在配额限制内
|
||||
// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号
|
||||
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
|
||||
if account.Type != AccountTypeAPIKey {
|
||||
if !account.IsAPIKeyOrBedrock() {
|
||||
return true
|
||||
}
|
||||
return !account.IsQuotaExceeded()
|
||||
@@ -3532,9 +3532,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
}
|
||||
return apiKey, "apikey", nil
|
||||
case AccountTypeBedrock:
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token
|
||||
case AccountTypeBedrockAPIKey:
|
||||
return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
|
||||
default:
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
@@ -5186,7 +5184,7 @@ func (s *GatewayService) forwardBedrock(
|
||||
if account.IsBedrockAPIKey() {
|
||||
bedrockAPIKey = account.GetCredential("api_key")
|
||||
if bedrockAPIKey == "" {
|
||||
return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials")
|
||||
return nil, fmt.Errorf("api_key not found in bedrock credentials")
|
||||
}
|
||||
} else {
|
||||
signer, err = NewBedrockSignerFromAccount(account)
|
||||
@@ -5375,8 +5373,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
@@ -5398,8 +5397,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5808,9 +5808,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
|
||||
return betaPolicyResult{}
|
||||
}
|
||||
isOAuth := account.IsOAuth()
|
||||
isBedrock := account.IsBedrock()
|
||||
var result betaPolicyResult
|
||||
for _, rule := range settings.Rules {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
switch rule.Action {
|
||||
@@ -5870,14 +5871,16 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
|
||||
}
|
||||
|
||||
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
|
||||
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
|
||||
func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
|
||||
switch scope {
|
||||
case BetaPolicyScopeAll:
|
||||
return true
|
||||
case BetaPolicyScopeOAuth:
|
||||
return isOAuth
|
||||
case BetaPolicyScopeAPIKey:
|
||||
return !isOAuth
|
||||
return !isOAuth && !isBedrock
|
||||
case BetaPolicyScopeBedrock:
|
||||
return isBedrock
|
||||
default:
|
||||
return true // unknown scope → match all (fail-open)
|
||||
}
|
||||
@@ -5959,12 +5962,13 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
|
||||
return nil
|
||||
}
|
||||
isOAuth := account.IsOAuth()
|
||||
isBedrock := account.IsBedrock()
|
||||
tokenSet := buildBetaTokenSet(tokens)
|
||||
for _, rule := range settings.Rules {
|
||||
if rule.Action != BetaPolicyActionBlock {
|
||||
continue
|
||||
}
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
if _, present := tokenSet[rule.BetaToken]; present {
|
||||
@@ -7199,7 +7203,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
@@ -7287,7 +7291,7 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
}
|
||||
|
||||
|
||||
@@ -339,8 +339,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
|
||||
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'"
|
||||
fixIDPrefix := func(id string) string {
|
||||
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
|
||||
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
|
||||
fixCallIDPrefix := func(id string) string {
|
||||
if id == "" || strings.HasPrefix(id, "fc") {
|
||||
return id
|
||||
}
|
||||
@@ -358,8 +359,8 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||
newItem["id"] = fixIDPrefix(id)
|
||||
if id, ok := newItem["id"].(string); ok && strings.HasPrefix(id, "call_") {
|
||||
newItem["id"] = fixCallIDPrefix(id)
|
||||
}
|
||||
filtered = append(filtered, newItem)
|
||||
continue
|
||||
@@ -390,7 +391,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
|
||||
if callID != "" {
|
||||
fixedCallID := fixIDPrefix(callID)
|
||||
fixedCallID := fixCallIDPrefix(callID)
|
||||
if fixedCallID != callID {
|
||||
ensureCopy()
|
||||
newItem["call_id"] = fixedCallID
|
||||
@@ -404,14 +405,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
if !isCodexToolCallItemType(typ) {
|
||||
delete(newItem, "call_id")
|
||||
}
|
||||
} else {
|
||||
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||
fixedID := fixIDPrefix(id)
|
||||
if fixedID != id {
|
||||
ensureCopy()
|
||||
newItem["id"] = fixedID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filtered = append(filtered, newItem)
|
||||
|
||||
@@ -33,12 +33,63 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "item_reference", first["type"])
|
||||
require.Equal(t, "fc_ref1", first["id"])
|
||||
require.Equal(t, "ref1", first["id"])
|
||||
|
||||
// 校验 input[1] 为 map,确保后续字段断言安全。
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc_o1", second["id"])
|
||||
require.Equal(t, "o1", second["id"])
|
||||
require.Equal(t, "fc1", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
|
||||
map[string]any{"type": "item_reference", "id": "rs_123"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "msg_0", first["id"])
|
||||
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rs_123", second["id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc1", first["id"])
|
||||
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc1", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
|
||||
@@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||
mappedModel = defaultMappedModel
|
||||
}
|
||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai chat_completions: model mapping applied",
|
||||
|
||||
@@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
// 分组级降级:账号未映射时使用分组默认映射模型
|
||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||
mappedModel = defaultMappedModel
|
||||
}
|
||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
|
||||
19
backend/internal/service/openai_model_mapping.go
Normal file
19
backend/internal/service/openai_model_mapping.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
||||
// forwarding. Group-level default mapping only applies when the account itself
|
||||
// did not match any explicit model_mapping rule.
|
||||
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
||||
if account == nil {
|
||||
if defaultMappedModel != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
mappedModel, matched := account.ResolveMappedModel(requestedModel)
|
||||
if !matched && defaultMappedModel != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return mappedModel
|
||||
}
|
||||
70
backend/internal/service/openai_model_mapping_test.go
Normal file
70
backend/internal/service/openai_model_mapping_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
requestedModel string
|
||||
defaultMappedModel string
|
||||
expectedModel string
|
||||
}{
|
||||
{
|
||||
name: "falls back to group default when account has no mapping",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
name: "preserves exact passthrough mapping instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "preserves wildcard passthrough mapping instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "uses account remap when explicit target differs",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel {
|
||||
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -371,6 +371,8 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
|
||||
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
|
||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||
DisplayOpenAITokenStats: false,
|
||||
DisplayAlertEvents: true,
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
}
|
||||
@@ -438,7 +440,7 @@ func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &OpsAdvancedSettings{}
|
||||
cfg := defaultOpsAdvancedSettings()
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
|
||||
97
backend/internal/service/ops_settings_advanced_test.go
Normal file
97
backend/internal/service/ops_settings_advanced_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if cfg.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = true, want false by default")
|
||||
}
|
||||
if !cfg.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = false, want true by default")
|
||||
}
|
||||
if repo.setCalls != 1 {
|
||||
t.Fatalf("expected defaults to be persisted once, got %d", repo.setCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOpsAdvancedSettings_PersistsOpenAITokenStatsVisibility(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
cfg := defaultOpsAdvancedSettings()
|
||||
cfg.DisplayOpenAITokenStats = true
|
||||
cfg.DisplayAlertEvents = false
|
||||
|
||||
updated, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if !updated.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = false, want true")
|
||||
}
|
||||
if updated.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = true, want false")
|
||||
}
|
||||
|
||||
reloaded, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() after update error = %v", err)
|
||||
}
|
||||
if !reloaded.DisplayOpenAITokenStats {
|
||||
t.Fatalf("reloaded DisplayOpenAITokenStats = false, want true")
|
||||
}
|
||||
if reloaded.DisplayAlertEvents {
|
||||
t.Fatalf("reloaded DisplayAlertEvents = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
legacyCfg := map[string]any{
|
||||
"data_retention": map[string]any{
|
||||
"cleanup_enabled": false,
|
||||
"cleanup_schedule": "0 2 * * *",
|
||||
"error_log_retention_days": 30,
|
||||
"minute_metrics_retention_days": 30,
|
||||
"hourly_metrics_retention_days": 30,
|
||||
},
|
||||
"aggregation": map[string]any{
|
||||
"aggregation_enabled": false,
|
||||
},
|
||||
"ignore_count_tokens_errors": true,
|
||||
"ignore_context_canceled": true,
|
||||
"ignore_no_available_accounts": false,
|
||||
"ignore_invalid_api_key_errors": false,
|
||||
"auto_refresh_enabled": false,
|
||||
"auto_refresh_interval_seconds": 30,
|
||||
}
|
||||
raw, err := json.Marshal(legacyCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal legacy config: %v", err)
|
||||
}
|
||||
repo.values[SettingKeyOpsAdvancedSettings] = string(raw)
|
||||
|
||||
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if cfg.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = true, want false default backfill")
|
||||
}
|
||||
if !cfg.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
|
||||
}
|
||||
}
|
||||
@@ -98,6 +98,8 @@ type OpsAdvancedSettings struct {
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
||||
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
|
||||
DisplayAlertEvents bool `json:"display_alert_events"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
}
|
||||
|
||||
@@ -149,8 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
// 其他 400 错误(如参数问题)不处理,不禁用账号
|
||||
case 401:
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。
|
||||
// Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。
|
||||
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
|
||||
// 1. 失效缓存
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
@@ -182,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
shouldDisable = true
|
||||
} else {
|
||||
// 非 OAuth 账号(APIKey):保持原有 SetError 行为
|
||||
// 非 OAuth / Antigravity OAuth:保持 SetError 行为
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Authentication failed (401): " + upstreamMsg
|
||||
@@ -199,11 +200,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
case 403:
|
||||
// 禁止访问:停止调度,记录错误
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.ratelimit",
|
||||
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
|
||||
@@ -215,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
upstreamMsg,
|
||||
truncateForLog(responseBody, 1024),
|
||||
)
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody)
|
||||
case 429:
|
||||
s.handle429(ctx, account, headers, responseBody)
|
||||
shouldDisable = false
|
||||
@@ -621,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
|
||||
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handle403 处理 403 Forbidden 错误
|
||||
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
|
||||
// 其他平台保持原有 SetError 行为。
|
||||
func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
|
||||
}
|
||||
// 非 Antigravity 平台:保持原有行为
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleAntigravity403 处理 Antigravity 平台的 403 错误
|
||||
// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复)
|
||||
// violation(违规封号)→ 永久 SetError(需人工处理)
|
||||
// generic(通用禁止)→ 永久 SetError
|
||||
func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
|
||||
fbType := classifyForbiddenType(string(responseBody))
|
||||
|
||||
switch fbType {
|
||||
case forbiddenTypeValidation:
|
||||
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
|
||||
msg := "Validation required (403): account needs Google verification"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Validation required (403): " + upstreamMsg
|
||||
}
|
||||
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
|
||||
msg += " | validation_url: " + validationURL
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
|
||||
case forbiddenTypeViolation:
|
||||
// 违规封号: 永久禁用,需人工处理
|
||||
msg := "Account violation (403): terms of service violation"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Account violation (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
|
||||
default:
|
||||
// 通用 403: 保持原有行为
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||
@@ -1213,7 +1264,8 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
|
||||
}
|
||||
// 401 首次命中可临时不可调度(给 token 刷新窗口);
|
||||
// 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
// Antigravity 跳过:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制,无需升级逻辑。
|
||||
if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity {
|
||||
reason := account.TempUnschedulableReason
|
||||
// 缓存可能没有 reason,从 DB 回退读取
|
||||
if reason == "" {
|
||||
|
||||
@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
|
||||
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
|
||||
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
// but DB account has a previous 401 record.
|
||||
// Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
|
||||
// Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
|
||||
t.Run("gemini_escalates", func(t *testing.T) {
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "", // cache miss — reason is empty
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformGemini,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
|
||||
})
|
||||
|
||||
t.Run("antigravity_stays_temp", func(t *testing.T) {
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
|
||||
|
||||
@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
}{
|
||||
{name: "gemini", platform: PlatformGemini},
|
||||
{name: "antigravity", platform: PlatformAntigravity},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: tt.platform,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
t.Run("gemini", func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
}
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
|
||||
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
|
||||
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
|
||||
// HandleUpstreamError 中走 SetError 路径。
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
require.Empty(t, invalidator.accounts)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||||
|
||||
@@ -65,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second
|
||||
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
|
||||
const minVersionDBTimeout = 5 * time.Second
|
||||
|
||||
// cachedBackendMode Backend Mode cache (in-process, 60s TTL)
|
||||
type cachedBackendMode struct {
|
||||
value bool
|
||||
expiresAt int64 // unix nano
|
||||
}
|
||||
|
||||
var backendModeCache atomic.Value // *cachedBackendMode
|
||||
var backendModeSF singleflight.Group
|
||||
|
||||
const backendModeCacheTTL = 60 * time.Second
|
||||
const backendModeErrorTTL = 5 * time.Second
|
||||
const backendModeDBTimeout = 5 * time.Second
|
||||
|
||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||
type DefaultSubscriptionGroupReader interface {
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
@@ -128,6 +141,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeySoraClientEnabled,
|
||||
SettingKeyCustomMenuItems,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
SettingKeyBackendModeEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -172,6 +186,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -223,6 +238,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
@@ -247,6 +263,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
}
|
||||
@@ -461,6 +478,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// 分组隔离
|
||||
updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling)
|
||||
|
||||
// Backend Mode
|
||||
updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled)
|
||||
|
||||
err = s.settingRepo.SetMultiple(ctx, updates)
|
||||
if err == nil {
|
||||
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
|
||||
@@ -469,6 +489,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
value: settings.MinClaudeCodeVersion,
|
||||
expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
|
||||
})
|
||||
backendModeSF.Forget("backend_mode")
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: settings.BackendModeEnabled,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
if s.onUpdate != nil {
|
||||
s.onUpdate() // Invalidate cache after settings update
|
||||
}
|
||||
@@ -525,6 +550,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsBackendModeEnabled checks if backend mode is enabled
|
||||
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path
|
||||
func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
|
||||
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.value
|
||||
}
|
||||
}
|
||||
result, _, _ := backendModeSF.Do("backend_mode", func() (any, error) {
|
||||
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.value, nil
|
||||
}
|
||||
}
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), backendModeDBTimeout)
|
||||
defer cancel()
|
||||
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyBackendModeEnabled)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
// Setting not yet created (fresh install) - default to disabled with full TTL
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: false,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
slog.Warn("failed to get backend_mode_enabled setting", "error", err)
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: false,
|
||||
expiresAt: time.Now().Add(backendModeErrorTTL).UnixNano(),
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
enabled := value == "true"
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: enabled,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
return enabled, nil
|
||||
})
|
||||
if val, ok := result.(bool); ok {
|
||||
return val
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
|
||||
@@ -719,6 +790,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
@@ -1278,7 +1350,7 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
|
||||
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
|
||||
}
|
||||
validScopes := map[string]bool{
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true,
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
|
||||
}
|
||||
|
||||
for i, rule := range settings.Rules {
|
||||
|
||||
199
backend/internal/service/setting_service_backend_mode_test.go
Normal file
199
backend/internal/service/setting_service_backend_mode_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmRepoStub struct {
|
||||
getValueFn func(ctx context.Context, key string) (string, error)
|
||||
calls int
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
s.calls++
|
||||
if s.getValueFn == nil {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
return s.getValueFn(ctx, key)
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
type bmUpdateRepoStub struct {
|
||||
updates map[string]string
|
||||
getValueFn func(ctx context.Context, key string) (string, error)
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if s.getValueFn == nil {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
return s.getValueFn(ctx, key)
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
s.updates = make(map[string]string, len(settings))
|
||||
for k, v := range settings {
|
||||
s.updates[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func resetBackendModeTestCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendModeCache.Store((*cachedBackendMode)(nil))
|
||||
t.Cleanup(func() {
|
||||
backendModeCache.Store((*cachedBackendMode)(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsTrue(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalse(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "false", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalseOnNotFound(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "", ErrSettingNotFound
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalseOnDBError(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "", errors.New("db down")
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_CachesResult(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestUpdateSettings_InvalidatesBackendModeCache(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: true,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
|
||||
repo := &bmUpdateRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
BackendModeEnabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "false", repo.updates[SettingKeyBackendModeEnabled])
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
}
|
||||
@@ -69,6 +69,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离:允许未分组 Key 调度(默认 false → 403)
|
||||
AllowUngroupedKeyScheduling bool
|
||||
|
||||
// Backend 模式:禁用用户注册和自助服务,仅管理员可登录
|
||||
BackendModeEnabled bool
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -101,6 +104,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
|
||||
LinuxDoOAuthEnabled bool
|
||||
BackendModeEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
@@ -198,16 +202,17 @@ const (
|
||||
BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token
|
||||
BetaPolicyActionBlock = "block" // 拦截,直接返回错误
|
||||
|
||||
BetaPolicyScopeAll = "all" // 所有账号类型
|
||||
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
|
||||
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
|
||||
BetaPolicyScopeAll = "all" // 所有账号类型
|
||||
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
|
||||
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
|
||||
BetaPolicyScopeBedrock = "bedrock" // 仅 AWS Bedrock 账号
|
||||
)
|
||||
|
||||
// BetaPolicyRule 单条 Beta 策略规则
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"` // beta token 值
|
||||
Action string `json:"action"` // "pass" | "filter" | "block"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
|
||||
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user