mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 16:30:22 +08:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
474165d7aa | ||
|
|
94e067a2e2 | ||
|
|
4293c89166 | ||
|
|
ec82c37da5 | ||
|
|
552a4b998a | ||
|
|
0d2061b268 | ||
|
|
8a260defc2 | ||
|
|
e14c87597a | ||
|
|
f3f19d35aa | ||
|
|
ced90e1d84 | ||
|
|
17e4033340 | ||
|
|
044d3a013d | ||
|
|
1fc9dd7b68 | ||
|
|
8147866c09 | ||
|
|
7bd1972f94 | ||
|
|
2c9dcfe27b | ||
|
|
1b79b0f3ff | ||
|
|
c637e6cf31 | ||
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b |
@@ -5,7 +5,12 @@
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
@@ -16,8 +21,20 @@ RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||
# restore work in the runtime container without requiring Docker socket access.
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
@@ -124,6 +124,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
@@ -132,11 +133,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageCache := service.NewUsageCache()
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
@@ -166,10 +167,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
@@ -232,7 +233,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
|
||||
@@ -512,6 +512,8 @@ func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||
payload := gin.H{
|
||||
"ranking": ranking.Ranking,
|
||||
"total_actual_cost": ranking.TotalActualCost,
|
||||
"total_requests": ranking.TotalRequests,
|
||||
"total_tokens": ranking.TotalTokens,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
}
|
||||
|
||||
@@ -61,6 +61,8 @@ func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||
return &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: s.ranking,
|
||||
TotalActualCost: s.rankingTotal,
|
||||
TotalRequests: 44,
|
||||
TotalTokens: 1234,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -164,6 +166,8 @@ func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 50, repo.rankingLimit)
|
||||
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
|
||||
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
|
||||
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -16,6 +19,55 @@ type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
type optionalLimitField struct {
|
||||
set bool
|
||||
value *float64
|
||||
}
|
||||
|
||||
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
|
||||
f.set = true
|
||||
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if bytes.Equal(trimmed, []byte("null")) {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var number float64
|
||||
if err := json.Unmarshal(trimmed, &number); err == nil {
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(trimmed, &text); err == nil {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
number, err = strconv.ParseFloat(text, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
|
||||
}
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid limit value: %s", string(trimmed))
|
||||
}
|
||||
|
||||
func (f optionalLimitField) ToServiceInput() *float64 {
|
||||
if !f.set {
|
||||
return nil
|
||||
}
|
||||
if f.value != nil {
|
||||
return f.value
|
||||
}
|
||||
zero := 0.0
|
||||
return &zero
|
||||
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
@@ -25,15 +77,15 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -62,16 +114,16 @@ type CreateGroupRequest struct {
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -191,9 +243,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -244,9 +296,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
|
||||
@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
FrontendURL: settings.FrontendURL,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -137,6 +138,7 @@ type UpdateSettingsRequest struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
@@ -326,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Frontend URL 验证
|
||||
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||
if req.FrontendURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义菜单项验证
|
||||
const (
|
||||
maxCustomMenuItems = 20
|
||||
@@ -437,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
@@ -531,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
FrontendURL: updatedSettings.FrontendURL,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -614,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.FrontendURL != after.FrontendURL {
|
||||
changed = append(changed, "frontend_url")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
|
||||
@@ -159,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||
t = t.AddDate(0, 0, 1)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
@@ -285,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||
endTime = endTime.AddDate(0, 0, 1)
|
||||
} else {
|
||||
period := c.DefaultQuery("period", "today")
|
||||
switch period {
|
||||
|
||||
@@ -459,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -523,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
InboundEndpoint: l.InboundEndpoint,
|
||||
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
|
||||
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
inboundEndpoint := "/v1/chat/completions"
|
||||
upstreamEndpoint := "/v1/responses"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
InboundEndpoint: &inboundEndpoint,
|
||||
UpstreamEndpoint: &upstreamEndpoint,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, userDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ type SystemSettings struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
@@ -334,9 +334,13 @@ type UsageLog struct {
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
174
backend/internal/handler/endpoint.go
Normal file
174
backend/internal/handler/endpoint.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Canonical inbound / upstream endpoint paths.
|
||||
// All normalization and derivation reference this single set
|
||||
// of constants — add new paths HERE when a new API surface
|
||||
// is introduced.
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
const (
|
||||
EndpointMessages = "/v1/messages"
|
||||
EndpointChatCompletions = "/v1/chat/completions"
|
||||
EndpointResponses = "/v1/responses"
|
||||
EndpointGeminiModels = "/v1beta/models"
|
||||
)
|
||||
|
||||
// gin.Context keys used by the middleware and helpers below.
|
||||
const (
|
||||
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
|
||||
)
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Normalization functions
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||
// prefixes like /antigravity, /openai, /sora) to its canonical form.
|
||||
//
|
||||
// "/antigravity/v1/messages" → "/v1/messages"
|
||||
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||
// "/openai/v1/responses/foo" → "/v1/responses"
|
||||
// "/v1beta/models/gemini:gen" → "/v1beta/models"
|
||||
func NormalizeInboundEndpoint(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
switch {
|
||||
case strings.Contains(path, EndpointChatCompletions):
|
||||
return EndpointChatCompletions
|
||||
case strings.Contains(path, EndpointMessages):
|
||||
return EndpointMessages
|
||||
case strings.Contains(path, EndpointResponses):
|
||||
return EndpointResponses
|
||||
case strings.Contains(path, EndpointGeminiModels):
|
||||
return EndpointGeminiModels
|
||||
default:
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveUpstreamEndpoint determines the upstream endpoint from the
|
||||
// account platform and the normalized inbound endpoint.
|
||||
//
|
||||
// Platform-specific rules:
|
||||
// - OpenAI always forwards to /v1/responses (with optional subpath
|
||||
// such as /v1/responses/compact preserved from the raw URL).
|
||||
// - Anthropic → /v1/messages
|
||||
// - Gemini → /v1beta/models
|
||||
// - Sora → /v1/chat/completions
|
||||
// - Antigravity routes may target either Claude or Gemini, so the
|
||||
// inbound endpoint is used to distinguish.
|
||||
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||
inbound = strings.TrimSpace(inbound)
|
||||
|
||||
switch platform {
|
||||
case service.PlatformOpenAI:
|
||||
// OpenAI forwards everything to the Responses API.
|
||||
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||
return EndpointResponses + suffix
|
||||
}
|
||||
return EndpointResponses
|
||||
|
||||
case service.PlatformAnthropic:
|
||||
return EndpointMessages
|
||||
|
||||
case service.PlatformGemini:
|
||||
return EndpointGeminiModels
|
||||
|
||||
case service.PlatformSora:
|
||||
return EndpointChatCompletions
|
||||
|
||||
case service.PlatformAntigravity:
|
||||
// Antigravity accounts serve both Claude and Gemini.
|
||||
if inbound == EndpointGeminiModels {
|
||||
return EndpointGeminiModels
|
||||
}
|
||||
return EndpointMessages
|
||||
}
|
||||
|
||||
// Unknown platform — fall back to inbound.
|
||||
return inbound
|
||||
}
|
||||
|
||||
// responsesSubpathSuffix extracts the part after "/responses" in a raw
|
||||
// request path, e.g. "/openai/v1/responses/compact" → "/compact".
|
||||
// Returns "" when there is no meaningful suffix.
|
||||
func responsesSubpathSuffix(rawPath string) string {
|
||||
trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")
|
||||
idx := strings.LastIndex(trimmed, "/responses")
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
suffix := trimmed[idx+len("/responses"):]
|
||||
if suffix == "" || suffix == "/" {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(suffix, "/") {
|
||||
return ""
|
||||
}
|
||||
return suffix
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Middleware
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// InboundEndpointMiddleware normalizes the request path and stores the
|
||||
// canonical inbound endpoint in gin.Context so that every handler in
|
||||
// the chain can read it via GetInboundEndpoint.
|
||||
//
|
||||
// Apply this middleware to all gateway route groups.
|
||||
func InboundEndpointMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.FullPath()
|
||||
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// Context helpers — used by handlers before building
|
||||
// RecordUsageInput / RecordUsageLongContextInput.
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// GetInboundEndpoint returns the canonical inbound endpoint stored by
|
||||
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
|
||||
// tests), it falls back to normalizing c.FullPath() on the fly.
|
||||
func GetInboundEndpoint(c *gin.Context) string {
|
||||
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
// Fallback: normalize on the fly.
|
||||
path := ""
|
||||
if c != nil {
|
||||
path = c.FullPath()
|
||||
if path == "" && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
}
|
||||
return NormalizeInboundEndpoint(path)
|
||||
}
|
||||
|
||||
// GetUpstreamEndpoint derives the upstream endpoint from the context
|
||||
// and the account platform. Handlers call this after scheduling an
|
||||
// account, passing account.Platform.
|
||||
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
|
||||
inbound := GetInboundEndpoint(c)
|
||||
rawPath := ""
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
rawPath = c.Request.URL.Path
|
||||
}
|
||||
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
|
||||
}
|
||||
159
backend/internal/handler/endpoint_test.go
Normal file
159
backend/internal/handler/endpoint_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() { gin.SetMode(gin.TestMode) }
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// NormalizeInboundEndpoint
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
// Direct canonical paths.
|
||||
{"/v1/messages", EndpointMessages},
|
||||
{"/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/v1/responses", EndpointResponses},
|
||||
{"/v1beta/models", EndpointGeminiModels},
|
||||
|
||||
// Prefixed paths (antigravity, openai, sora).
|
||||
{"/antigravity/v1/messages", EndpointMessages},
|
||||
{"/openai/v1/responses", EndpointResponses},
|
||||
{"/openai/v1/responses/compact", EndpointResponses},
|
||||
{"/sora/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||
|
||||
// Gin route patterns with wildcards.
|
||||
{"/v1beta/models/*modelAction", EndpointGeminiModels},
|
||||
{"/v1/responses/*subpath", EndpointResponses},
|
||||
|
||||
// Unknown path is returned as-is.
|
||||
{"/v1/embeddings", "/v1/embeddings"},
|
||||
{"", ""},
|
||||
{" /v1/messages ", EndpointMessages},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// DeriveUpstreamEndpoint
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inbound string
|
||||
rawPath string
|
||||
platform string
|
||||
want string
|
||||
}{
|
||||
// Anthropic.
|
||||
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
|
||||
|
||||
// Gemini.
|
||||
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||
|
||||
// Sora.
|
||||
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
|
||||
|
||||
// OpenAI — always /v1/responses.
|
||||
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||
|
||||
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
|
||||
|
||||
// Unknown platform — passthrough.
|
||||
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// responsesSubpathSuffix
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestResponsesSubpathSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{"/v1/responses", ""},
|
||||
{"/v1/responses/", ""},
|
||||
{"/v1/responses/compact", "/compact"},
|
||||
{"/openai/v1/responses/compact/detail", "/compact/detail"},
|
||||
{"/v1/messages", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.raw, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// InboundEndpointMiddleware + context helpers
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestInboundEndpointMiddleware(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(InboundEndpointMiddleware())
|
||||
|
||||
var captured string
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
captured = GetInboundEndpoint(c)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, EndpointMessages, captured)
|
||||
}
|
||||
|
||||
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
|
||||
|
||||
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
|
||||
got := GetInboundEndpoint(c)
|
||||
require.Equal(t, EndpointMessages, got)
|
||||
}
|
||||
|
||||
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
|
||||
|
||||
// Simulate middleware.
|
||||
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
|
||||
|
||||
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||
require.Equal(t, "/v1/responses/compact", got)
|
||||
}
|
||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -435,6 +442,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
@@ -444,6 +457,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -637,6 +652,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -706,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -739,6 +761,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
@@ -748,6 +776,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -913,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
|
||||
}
|
||||
if s := c.Query("end_date"); s != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||
endTime = t.Add(24*time.Hour - time.Second) // end of day
|
||||
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
|
||||
}
|
||||
}
|
||||
return startTime, endTime
|
||||
|
||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||
|
||||
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||
// 具体验证:
|
||||
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||
|
||||
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||
|
||||
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||
}
|
||||
|
||||
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||
|
||||
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||
|
||||
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx,
|
||||
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
}
|
||||
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
require.Contains(t, body, "event: message_start")
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||
// c.Writer.Size() 仍为 -1
|
||||
|
||||
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||
require.False(t, guardTriggered,
|
||||
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||
}
|
||||
@@ -504,6 +504,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -511,6 +513,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
|
||||
@@ -256,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
|
||||
// unified GetUpstreamEndpoint helper produces the same results as the
|
||||
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
|
||||
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "responses root maps to responses upstream",
|
||||
path: "/v1/responses",
|
||||
want: EndpointResponses,
|
||||
},
|
||||
{
|
||||
name: "responses compact keeps compact suffix",
|
||||
path: "/openai/v1/responses/compact",
|
||||
want: "/v1/responses/compact",
|
||||
},
|
||||
{
|
||||
name: "responses nested suffix preserved",
|
||||
path: "/openai/v1/responses/compact/detail",
|
||||
want: "/v1/responses/compact/detail",
|
||||
},
|
||||
{
|
||||
name: "non responses path uses platform fallback",
|
||||
path: "/v1/messages",
|
||||
want: EndpointResponses,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||
|
||||
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -362,6 +362,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -738,6 +740,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -1235,6 +1239,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
|
||||
@@ -26,6 +26,22 @@ const (
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
|
||||
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||
opsErrContextCanceled = "context canceled"
|
||||
opsErrNoAvailableAccounts = "no available accounts"
|
||||
opsErrInvalidAPIKey = "invalid_api_key"
|
||||
opsErrAPIKeyRequired = "api_key_required"
|
||||
opsErrInsufficientBalance = "insufficient balance"
|
||||
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||
opsErrInsufficientQuota = "insufficient_quota"
|
||||
|
||||
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
||||
return errType
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE":
|
||||
case opsCodeInsufficientBalance:
|
||||
return "billing_error"
|
||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "subscription_error"
|
||||
default:
|
||||
return "api_error"
|
||||
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "request"
|
||||
}
|
||||
|
||||
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, "no available accounts") {
|
||||
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
||||
|
||||
// Check if context canceled errors should be ignored (client disconnects)
|
||||
if settings.IgnoreContextCanceled {
|
||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if "no available accounts" errors should be ignored
|
||||
if settings.IgnoreNoAvailableAccounts {
|
||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||
if settings.IgnoreInvalidApiKeyErrors {
|
||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
||||
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if insufficient balance errors should be ignored
|
||||
if settings.IgnoreInsufficientBalanceErrors {
|
||||
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -400,6 +400,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
@@ -409,6 +411,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
|
||||
@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
// Use half-open range [start, end), move to next calendar day start (DST-safe).
|
||||
t = t.AddDate(0, 0, 1)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
@@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// 设置结束时间为当天结束
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界(DST-safe)。
|
||||
endTime = endTime.AddDate(0, 0, 1)
|
||||
} else {
|
||||
// 使用 period 参数
|
||||
period := c.DefaultQuery("period", "today")
|
||||
|
||||
@@ -124,10 +124,68 @@ type IneligibleTier struct {
|
||||
type LoadCodeAssistResponse struct {
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
||||
PaidTier *PaidTierInfo `json:"paidTier,omitempty"`
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
// PaidTierInfo 付费等级信息,包含 AI Credits 余额。
|
||||
type PaidTierInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON 兼容 paidTier 既可能是字符串也可能是对象的情况。
|
||||
func (p *PaidTierInfo) UnmarshalJSON(data []byte) error {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var id string
|
||||
if err := json.Unmarshal(data, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
p.ID = id
|
||||
return nil
|
||||
}
|
||||
type alias PaidTierInfo
|
||||
var raw alias
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
*p = PaidTierInfo(raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AvailableCredit 表示一条 AI Credits 余额记录。
|
||||
type AvailableCredit struct {
|
||||
CreditType string `json:"creditType,omitempty"`
|
||||
CreditAmount string `json:"creditAmount,omitempty"`
|
||||
MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"`
|
||||
}
|
||||
|
||||
// GetAmount 将 creditAmount 解析为浮点数。
|
||||
func (c *AvailableCredit) GetAmount() float64 {
|
||||
if c.CreditAmount == "" {
|
||||
return 0
|
||||
}
|
||||
var value float64
|
||||
_, _ = fmt.Sscanf(c.CreditAmount, "%f", &value)
|
||||
return value
|
||||
}
|
||||
|
||||
// GetMinimumAmount 将 minimumCreditAmountForUsage 解析为浮点数。
|
||||
func (c *AvailableCredit) GetMinimumAmount() float64 {
|
||||
if c.MinimumCreditAmountForUsage == "" {
|
||||
return 0
|
||||
}
|
||||
var value float64
|
||||
_, _ = fmt.Sscanf(c.MinimumCreditAmountForUsage, "%f", &value)
|
||||
return value
|
||||
}
|
||||
|
||||
// OnboardUserRequest onboardUser 请求
|
||||
type OnboardUserRequest struct {
|
||||
TierID string `json:"tierId"`
|
||||
@@ -157,6 +215,14 @@ func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAvailableCredits 返回 paid tier 中的 AI Credits 余额列表。
|
||||
func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
|
||||
if r.PaidTier == nil {
|
||||
return nil
|
||||
}
|
||||
return r.PaidTier.AvailableCredits
|
||||
}
|
||||
|
||||
// Client Antigravity API 客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
|
||||
@@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
|
||||
func TestGetTier_PaidTier优先(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||
PaidTier: &TierInfo{ID: "g1-pro-tier"},
|
||||
PaidTier: &PaidTierInfo{ID: "g1-pro-tier"},
|
||||
}
|
||||
if got := resp.GetTier(); got != "g1-pro-tier" {
|
||||
t.Errorf("应返回 paidTier: got %s", got)
|
||||
@@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) {
|
||||
func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
CurrentTier: &TierInfo{ID: "free-tier"},
|
||||
PaidTier: &TierInfo{ID: ""},
|
||||
PaidTier: &PaidTierInfo{ID: ""},
|
||||
}
|
||||
// paidTier.ID 为空时应回退到 currentTier
|
||||
if got := resp.GetTier(); got != "free-tier" {
|
||||
@@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableCredits(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{
|
||||
PaidTier: &PaidTierInfo{
|
||||
ID: "g1-pro-tier",
|
||||
AvailableCredits: []AvailableCredit{
|
||||
{
|
||||
CreditType: "GOOGLE_ONE_AI",
|
||||
CreditAmount: "25",
|
||||
MinimumCreditAmountForUsage: "5",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
credits := resp.GetAvailableCredits()
|
||||
if len(credits) != 1 {
|
||||
t.Fatalf("AI Credits 数量不匹配: got %d", len(credits))
|
||||
}
|
||||
if credits[0].GetAmount() != 25 {
|
||||
t.Errorf("CreditAmount 解析不正确: got %v", credits[0].GetAmount())
|
||||
}
|
||||
if credits[0].GetMinimumAmount() != 5 {
|
||||
t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTier_两者都为nil(t *testing.T) {
|
||||
resp := &LoadCodeAssistResponse{}
|
||||
if got := resp.GetTier(); got != "" {
|
||||
|
||||
@@ -81,6 +81,15 @@ type ModelStat struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// EndpointStat represents usage statistics for a single request endpoint.
|
||||
type EndpointStat struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Requests int64 `json:"requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
type GroupStat struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
@@ -116,6 +125,8 @@ type UserSpendingRankingItem struct {
|
||||
type UserSpendingRankingResponse struct {
|
||||
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
@@ -179,15 +190,18 @@ type UsageLogFilters struct {
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
Endpoints []EndpointStat `json:"endpoints,omitempty"`
|
||||
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
|
||||
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
@@ -254,7 +268,9 @@ type AccountUsageSummary struct {
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
Endpoints []EndpointStat `json:"endpoints"`
|
||||
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
|
||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
||||
|
||||
var usageLogInsertArgTypes = [...]string{
|
||||
"bigint",
|
||||
@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"boolean",
|
||||
"timestamptz",
|
||||
}
|
||||
@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*37)
|
||||
args := make([]any, 0, len(keys)*38)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
)
|
||||
@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*36)
|
||||
args := make([]any, 0, len(preparedList)*38)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
)
|
||||
@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
mediaType := nullString(log.MediaType)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||
upstreamEndpoint := nullString(log.UpstreamEndpoint)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
inboundEndpoint,
|
||||
upstreamEndpoint,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
},
|
||||
@@ -2139,7 +2161,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
actual_cost,
|
||||
requests,
|
||||
tokens,
|
||||
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost
|
||||
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost,
|
||||
COALESCE(SUM(requests) OVER (), 0) as total_requests,
|
||||
COALESCE(SUM(tokens) OVER (), 0) as total_tokens
|
||||
FROM user_spend
|
||||
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
|
||||
LIMIT $3
|
||||
@@ -2150,7 +2174,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
actual_cost,
|
||||
requests,
|
||||
tokens,
|
||||
total_actual_cost
|
||||
total_actual_cost,
|
||||
total_requests,
|
||||
total_tokens
|
||||
FROM ranked
|
||||
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
|
||||
`
|
||||
@@ -2168,9 +2194,11 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
|
||||
ranking := make([]UserSpendingRankingItem, 0)
|
||||
totalActualCost := 0.0
|
||||
totalRequests := int64(0)
|
||||
totalTokens := int64(0)
|
||||
for rows.Next() {
|
||||
var row UserSpendingRankingItem
|
||||
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost); err != nil {
|
||||
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost, &totalRequests, &totalTokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ranking = append(ranking, row)
|
||||
@@ -2182,6 +2210,8 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
return &UserSpendingRankingResponse{
|
||||
Ranking: ranking,
|
||||
TotalActualCost: totalActualCost,
|
||||
TotalRequests: totalRequests,
|
||||
TotalTokens: totalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -2505,7 +2535,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
args = append(args, *filters.StartTime)
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||||
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
|
||||
args = append(args, *filters.EndTime)
|
||||
}
|
||||
|
||||
@@ -2982,7 +3012,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at <= $2
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
|
||||
stats := &UsageStats{}
|
||||
@@ -3040,7 +3070,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
args = append(args, *filters.StartTime)
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||||
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
|
||||
args = append(args, *filters.EndTime)
|
||||
}
|
||||
|
||||
@@ -3080,6 +3110,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
stats.TotalAccountCost = &totalAccountCost
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
|
||||
start := time.Unix(0, 0).UTC()
|
||||
if filters.StartTime != nil {
|
||||
start = *filters.StartTime
|
||||
}
|
||||
end := time.Now().UTC()
|
||||
if filters.EndTime != nil {
|
||||
end = *filters.EndTime
|
||||
}
|
||||
|
||||
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if endpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v", endpointErr)
|
||||
endpoints = []EndpointStat{}
|
||||
}
|
||||
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if upstreamEndpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v", upstreamEndpointErr)
|
||||
upstreamEndpoints = []EndpointStat{}
|
||||
}
|
||||
endpointPaths, endpointPathErr := r.getEndpointPathStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if endpointPathErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v", endpointPathErr)
|
||||
endpointPaths = []EndpointStat{}
|
||||
}
|
||||
stats.Endpoints = endpoints
|
||||
stats.UpstreamEndpoints = upstreamEndpoints
|
||||
stats.EndpointPaths = endpointPaths
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -3092,6 +3151,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||
|
||||
// EndpointStat represents endpoint usage statistics row.
|
||||
type EndpointStat = usagestats.EndpointStat
|
||||
|
||||
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint,
|
||||
COUNT(*) AS requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`, endpointColumn, actualCostExpr)
|
||||
|
||||
args := []any{startTime, endTime}
|
||||
if userID > 0 {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||||
args = append(args, userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||
args = append(args, apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||
args = append(args, accountID)
|
||||
}
|
||||
if groupID > 0 {
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
}
|
||||
query += " GROUP BY endpoint ORDER BY requests DESC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]EndpointStat, 0)
|
||||
for rows.Next() {
|
||||
var row EndpointStat
|
||||
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
CONCAT(
|
||||
COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'),
|
||||
' -> ',
|
||||
COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown')
|
||||
) AS endpoint,
|
||||
COUNT(*) AS requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`, actualCostExpr)
|
||||
|
||||
args := []any{startTime, endTime}
|
||||
if userID > 0 {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||||
args = append(args, userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||
args = append(args, apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||
args = append(args, accountID)
|
||||
}
|
||||
if groupID > 0 {
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
}
|
||||
query += " GROUP BY endpoint ORDER BY requests DESC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]EndpointStat, 0)
|
||||
for rows.Next() {
|
||||
var row EndpointStat
|
||||
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetEndpointStatsWithFilters returns inbound endpoint statistics with optional filters.
|
||||
func (r *usageLogRepository) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
|
||||
return r.getEndpointStatsByColumnWithFilters(ctx, "inbound_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
}
|
||||
|
||||
// GetUpstreamEndpointStatsWithFilters returns upstream endpoint statistics with optional filters.
|
||||
func (r *usageLogRepository) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
|
||||
return r.getEndpointStatsByColumnWithFilters(ctx, "upstream_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
}
|
||||
|
||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||
@@ -3254,11 +3470,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
|
||||
if endpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v", endpointErr)
|
||||
endpoints = []EndpointStat{}
|
||||
}
|
||||
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
|
||||
if upstreamEndpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v", upstreamEndpointErr)
|
||||
upstreamEndpoints = []EndpointStat{}
|
||||
}
|
||||
|
||||
resp = &AccountUsageStatsResponse{
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
Endpoints: endpoints,
|
||||
UpstreamEndpoints: upstreamEndpoints,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -3541,6 +3769,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
mediaType sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
inboundEndpoint sql.NullString
|
||||
upstreamEndpoint sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
)
|
||||
@@ -3581,6 +3811,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&mediaType,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&inboundEndpoint,
|
||||
&upstreamEndpoint,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
@@ -3656,6 +3888,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
if inboundEndpoint.Valid {
|
||||
log.InboundEndpoint = &inboundEndpoint.String
|
||||
}
|
||||
if upstreamEndpoint.Valid {
|
||||
log.UpstreamEndpoint = &upstreamEndpoint.String
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
|
||||
@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
sqlmock.AnyArg(), // inbound_endpoint
|
||||
sqlmock.AnyArg(), // upstream_endpoint
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
@@ -255,10 +259,10 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
|
||||
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
|
||||
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
|
||||
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
|
||||
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost", "total_requests", "total_tokens"}).
|
||||
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0, int64(30), int64(2600)).
|
||||
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0, int64(30), int64(2600)).
|
||||
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0, int64(30), int64(2600))
|
||||
|
||||
mock.ExpectQuery("WITH user_spend AS \\(").
|
||||
WithArgs(start, end, 12).
|
||||
@@ -273,6 +277,8 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
||||
},
|
||||
TotalActualCost: 40.0,
|
||||
TotalRequests: 30,
|
||||
TotalTokens: 2600,
|
||||
}, got)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
@@ -376,6 +382,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
@@ -415,6 +423,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
@@ -454,6 +464,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
|
||||
@@ -493,6 +493,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"registration_email_suffix_whitelist": [],
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"frontend_url": "",
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "smtp.example.com",
|
||||
@@ -1624,6 +1625,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ func RegisterGatewayRoutes(
|
||||
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
|
||||
clientRequestID := middleware.ClientRequestID()
|
||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||
endpointNorm := handler.InboundEndpointMiddleware()
|
||||
|
||||
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
|
||||
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
|
||||
@@ -40,6 +41,7 @@ func RegisterGatewayRoutes(
|
||||
gateway.Use(bodyLimit)
|
||||
gateway.Use(clientRequestID)
|
||||
gateway.Use(opsErrorLogger)
|
||||
gateway.Use(endpointNorm)
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
gateway.Use(requireGroupAnthropic)
|
||||
{
|
||||
@@ -80,6 +82,7 @@ func RegisterGatewayRoutes(
|
||||
gemini.Use(bodyLimit)
|
||||
gemini.Use(clientRequestID)
|
||||
gemini.Use(opsErrorLogger)
|
||||
gemini.Use(endpointNorm)
|
||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
gemini.Use(requireGroupGoogle)
|
||||
{
|
||||
@@ -90,11 +93,11 @@ func RegisterGatewayRoutes(
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||
@@ -104,6 +107,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1.Use(bodyLimit)
|
||||
antigravityV1.Use(clientRequestID)
|
||||
antigravityV1.Use(opsErrorLogger)
|
||||
antigravityV1.Use(endpointNorm)
|
||||
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
antigravityV1.Use(requireGroupAnthropic)
|
||||
@@ -118,6 +122,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta.Use(bodyLimit)
|
||||
antigravityV1Beta.Use(clientRequestID)
|
||||
antigravityV1Beta.Use(opsErrorLogger)
|
||||
antigravityV1Beta.Use(endpointNorm)
|
||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
antigravityV1Beta.Use(requireGroupGoogle)
|
||||
@@ -132,6 +137,7 @@ func RegisterGatewayRoutes(
|
||||
soraV1.Use(soraBodyLimit)
|
||||
soraV1.Use(clientRequestID)
|
||||
soraV1.Use(opsErrorLogger)
|
||||
soraV1.Use(endpointNorm)
|
||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
soraV1.Use(requireGroupAnthropic)
|
||||
|
||||
@@ -901,6 +901,22 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOveragesEnabled 检查 Antigravity 账号是否启用 AI Credits 超量请求。
|
||||
func (a *Account) IsOveragesEnabled() bool {
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["allow_overages"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
|
||||
//
|
||||
// 新字段:accounts.extra.openai_passthrough。
|
||||
|
||||
@@ -45,6 +45,8 @@ type UsageLogRepository interface {
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
@@ -164,6 +166,13 @@ type AntigravityModelDetail struct {
|
||||
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
||||
}
|
||||
|
||||
// AICredit 表示 Antigravity 账号的 AI Credits 余额信息。
|
||||
type AICredit struct {
|
||||
CreditType string `json:"credit_type,omitempty"`
|
||||
Amount float64 `json:"amount,omitempty"`
|
||||
MinimumBalance float64 `json:"minimum_balance,omitempty"`
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
@@ -187,6 +196,9 @@ type UsageInfo struct {
|
||||
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
||||
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
||||
|
||||
// Antigravity AI Credits 余额
|
||||
AICredits []AICredit `json:"ai_credits,omitempty"`
|
||||
|
||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||
|
||||
|
||||
@@ -368,6 +368,10 @@ type ProxyExitInfoProber interface {
|
||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
type proxyQualityTarget struct {
|
||||
Target string
|
||||
URL string
|
||||
@@ -445,10 +449,6 @@ type userGroupRateBatchReader interface {
|
||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(
|
||||
userRepo UserRepository,
|
||||
@@ -832,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
subscriptionType = SubscriptionTypeStandard
|
||||
}
|
||||
|
||||
// 限额字段:0 和 nil 都表示"无限制"
|
||||
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||
@@ -944,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
|
||||
// normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零)
|
||||
func normalizeLimit(limit *float64) *float64 {
|
||||
if limit == nil || *limit <= 0 {
|
||||
if limit == nil || *limit < 0 {
|
||||
return nil
|
||||
}
|
||||
return limit
|
||||
@@ -1058,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.SubscriptionType != "" {
|
||||
group.SubscriptionType = input.SubscriptionType
|
||||
}
|
||||
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
|
||||
if input.DailyLimitUSD != nil {
|
||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||
}
|
||||
if input.WeeklyLimitUSD != nil {
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
}
|
||||
if input.MonthlyLimitUSD != nil {
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
}
|
||||
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||
// 前端始终发送这三个字段,无需 nil 守卫
|
||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||
if input.ImagePrice1K != nil {
|
||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||
@@ -1521,6 +1516,7 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wasOveragesEnabled := account.IsOveragesEnabled()
|
||||
|
||||
if input.Name != "" {
|
||||
account.Name = input.Name
|
||||
@@ -1542,6 +1538,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
if account.Platform == PlatformAntigravity && wasOveragesEnabled && !account.IsOveragesEnabled() {
|
||||
delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态
|
||||
// 清除 AICredits 限流 key
|
||||
if rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any); ok {
|
||||
delete(rawLimits, creditsExhaustedKey)
|
||||
}
|
||||
}
|
||||
if account.Platform == PlatformAntigravity && !wasOveragesEnabled && account.IsOveragesEnabled() {
|
||||
delete(account.Extra, modelRateLimitsKey)
|
||||
delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态
|
||||
}
|
||||
// 校验并预计算固定时间重置的下次重置时间
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
|
||||
123
backend/internal/service/admin_service_overages_test.go
Normal file
123
backend/internal/service/admin_service_overages_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type updateAccountOveragesRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
account *Account
|
||||
updateCalls int
|
||||
}
|
||||
|
||||
func (r *updateAccountOveragesRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *updateAccountOveragesRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
r.updateCalls++
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateAccount_DisableOveragesClearsAICreditsKey(t *testing.T) {
|
||||
accountID := int64(101)
|
||||
repo := &updateAccountOveragesRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.False(t, updated.IsOveragesEnabled())
|
||||
|
||||
// 关闭 overages 后,AICredits key 应被清除
|
||||
rawLimits, ok := repo.account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
if ok {
|
||||
_, exists := rawLimits[creditsExhaustedKey]
|
||||
require.False(t, exists, "关闭 overages 时应清除 AICredits 限流 key")
|
||||
}
|
||||
// 普通模型限流应保留
|
||||
require.True(t, ok)
|
||||
_, exists := rawLimits["claude-sonnet-4-5"]
|
||||
require.True(t, exists, "普通模型限流应保留")
|
||||
}
|
||||
|
||||
func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testing.T) {
|
||||
accountID := int64(102)
|
||||
repo := &updateAccountOveragesRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
"allow_overages": true,
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.True(t, updated.IsOveragesEnabled())
|
||||
|
||||
_, exists := repo.account.Extra[modelRateLimitsKey]
|
||||
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
|
||||
}
|
||||
234
backend/internal/service/antigravity_credits_overages.go
Normal file
234
backend/internal/service/antigravity_credits_overages.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// creditsExhaustedKey 是 model_rate_limits 中标记积分耗尽的特殊 key。
|
||||
// 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。
|
||||
creditsExhaustedKey = "AICredits"
|
||||
creditsExhaustedDuration = 5 * time.Hour
|
||||
)
|
||||
|
||||
type antigravity429Category string
|
||||
|
||||
const (
|
||||
antigravity429Unknown antigravity429Category = "unknown"
|
||||
antigravity429RateLimited antigravity429Category = "rate_limited"
|
||||
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
|
||||
)
|
||||
|
||||
var (
|
||||
antigravityQuotaExhaustedKeywords = []string{
|
||||
"quota_exhausted",
|
||||
"quota exhausted",
|
||||
}
|
||||
|
||||
creditsExhaustedKeywords = []string{
|
||||
"google_one_ai",
|
||||
"insufficient credit",
|
||||
"insufficient credits",
|
||||
"not enough credit",
|
||||
"not enough credits",
|
||||
"credit exhausted",
|
||||
"credits exhausted",
|
||||
"credit balance",
|
||||
"minimumcreditamountforusage",
|
||||
"minimum credit amount for usage",
|
||||
"minimum credit",
|
||||
}
|
||||
)
|
||||
|
||||
// isCreditsExhausted 检查账号的 AICredits 限流 key 是否生效(积分是否耗尽)。
|
||||
func (a *Account) isCreditsExhausted() bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
return a.isRateLimitActiveForKey(creditsExhaustedKey)
|
||||
}
|
||||
|
||||
// setCreditsExhausted 标记账号积分耗尽:写入 model_rate_limits["AICredits"] + 更新缓存。
|
||||
func (s *AntigravityGatewayService) setCreditsExhausted(ctx context.Context, account *Account) {
|
||||
if account == nil || account.ID == 0 {
|
||||
return
|
||||
}
|
||||
resetAt := time.Now().Add(creditsExhaustedDuration)
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, creditsExhaustedKey, resetAt); err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "set credits exhausted failed: account=%d err=%v", account.ID, err)
|
||||
return
|
||||
}
|
||||
s.updateAccountModelRateLimitInCache(ctx, account, creditsExhaustedKey, resetAt)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "credits_exhausted_marked account=%d reset_at=%s",
|
||||
account.ID, resetAt.UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// clearCreditsExhausted 清除账号的 AICredits 限流 key。
|
||||
func (s *AntigravityGatewayService) clearCreditsExhausted(ctx context.Context, account *Account) {
|
||||
if account == nil || account.ID == 0 || account.Extra == nil {
|
||||
return
|
||||
}
|
||||
rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, exists := rawLimits[creditsExhaustedKey]; !exists {
|
||||
return
|
||||
}
|
||||
delete(rawLimits, creditsExhaustedKey)
|
||||
account.Extra[modelRateLimitsKey] = rawLimits
|
||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
||||
modelRateLimitsKey: rawLimits,
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "clear credits exhausted failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// classifyAntigravity429 将 Antigravity 的 429 响应归类为配额耗尽、限流或未知。
|
||||
func classifyAntigravity429(body []byte) antigravity429Category {
|
||||
if len(body) == 0 {
|
||||
return antigravity429Unknown
|
||||
}
|
||||
lowerBody := strings.ToLower(string(body))
|
||||
for _, keyword := range antigravityQuotaExhaustedKeywords {
|
||||
if strings.Contains(lowerBody, keyword) {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
}
|
||||
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
|
||||
return antigravity429RateLimited
|
||||
}
|
||||
return antigravity429Unknown
|
||||
}
|
||||
|
||||
// injectEnabledCreditTypes 在已序列化的 v1internal JSON body 中注入 AI Credits 类型。
|
||||
func injectEnabledCreditTypes(body []byte) []byte {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil
|
||||
}
|
||||
payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// resolveCreditsOveragesModelKey 解析当前请求对应的 overages 状态模型 key。
|
||||
func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstreamModelName, requestedModel string) string {
|
||||
modelKey := strings.TrimSpace(upstreamModelName)
|
||||
if modelKey != "" {
|
||||
return modelKey
|
||||
}
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
modelKey = resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||
if strings.TrimSpace(modelKey) != "" {
|
||||
return modelKey
|
||||
}
|
||||
return resolveAntigravityModelKey(requestedModel)
|
||||
}
|
||||
|
||||
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
|
||||
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
|
||||
if reqErr != nil || resp == nil {
|
||||
return false
|
||||
}
|
||||
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
|
||||
return false
|
||||
}
|
||||
if isURLLevelRateLimit(respBody) {
|
||||
return false
|
||||
}
|
||||
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
|
||||
return false
|
||||
}
|
||||
bodyLower := strings.ToLower(string(respBody))
|
||||
for _, keyword := range creditsExhaustedKeywords {
|
||||
if strings.Contains(bodyLower, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type creditsOveragesRetryResult struct {
|
||||
handled bool
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
// attemptCreditsOveragesRetry 在确认免费配额耗尽后,尝试注入 AI Credits 继续请求。
|
||||
func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
|
||||
p antigravityRetryLoopParams,
|
||||
baseURL string,
|
||||
modelName string,
|
||||
waitDuration time.Duration,
|
||||
originalStatusCode int,
|
||||
respBody []byte,
|
||||
) *creditsOveragesRetryResult {
|
||||
creditsBody := injectEnabledCreditTypes(p.body)
|
||||
if creditsBody == nil {
|
||||
return &creditsOveragesRetryResult{handled: false}
|
||||
}
|
||||
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
|
||||
p.prefix, modelKey, p.account.ID)
|
||||
|
||||
creditsReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, creditsBody)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d build_request_err=%v",
|
||||
p.prefix, modelKey, p.account.ID, err)
|
||||
return &creditsOveragesRetryResult{handled: true}
|
||||
}
|
||||
|
||||
creditsResp, err := p.httpUpstream.Do(creditsReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if err == nil && creditsResp != nil && creditsResp.StatusCode < 400 {
|
||||
s.clearCreditsExhausted(p.ctx, p.account)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d credit_overages_success model=%s account=%d",
|
||||
p.prefix, creditsResp.StatusCode, modelKey, p.account.ID)
|
||||
return &creditsOveragesRetryResult{handled: true, resp: creditsResp}
|
||||
}
|
||||
|
||||
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, creditsResp, err)
|
||||
return &creditsOveragesRetryResult{handled: true}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleCreditsRetryFailure(
|
||||
ctx context.Context,
|
||||
prefix string,
|
||||
modelKey string,
|
||||
account *Account,
|
||||
creditsResp *http.Response,
|
||||
reqErr error,
|
||||
) {
|
||||
var creditsRespBody []byte
|
||||
creditsStatusCode := 0
|
||||
if creditsResp != nil {
|
||||
creditsStatusCode = creditsResp.StatusCode
|
||||
if creditsResp.Body != nil {
|
||||
creditsRespBody, _ = io.ReadAll(io.LimitReader(creditsResp.Body, 64<<10))
|
||||
_ = creditsResp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if shouldMarkCreditsExhausted(creditsResp, creditsRespBody, reqErr) && account != nil {
|
||||
s.setCreditsExhausted(ctx, account)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=true status=%d body=%s",
|
||||
prefix, modelKey, account.ID, creditsStatusCode, truncateForLog(creditsRespBody, 200))
|
||||
return
|
||||
}
|
||||
if account != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=false status=%d err=%v body=%s",
|
||||
prefix, modelKey, account.ID, creditsStatusCode, reqErr, truncateForLog(creditsRespBody, 200))
|
||||
}
|
||||
}
|
||||
538
backend/internal/service/antigravity_credits_overages_test.go
Normal file
538
backend/internal/service/antigravity_credits_overages_test.go
Normal file
@@ -0,0 +1,538 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClassifyAntigravity429(t *testing.T) {
|
||||
t.Run("明确配额耗尽", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("结构化限流", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
require.Equal(t, antigravity429RateLimited, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("未知429", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"message":"too many requests"}}`)
|
||||
require.Equal(t, antigravity429Unknown, classifyAntigravity429(body))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCreditsExhausted_UsesAICreditsKey(t *testing.T) {
|
||||
t.Run("无 AICredits key 则积分可用", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
}
|
||||
require.False(t, account.isCreditsExhausted())
|
||||
})
|
||||
|
||||
t.Run("AICredits key 生效则积分耗尽", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.True(t, account.isCreditsExhausted())
|
||||
})
|
||||
|
||||
t.Run("AICredits key 过期则积分可用", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().Add(-6 * time.Hour).UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.False(t, account.isCreditsExhausted())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Name: "acc-101",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-opus-4-6": "claude-sonnet-4-5",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
respBody := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
requestedModel: "claude-opus-4-6",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Nil(t, result.switchError)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Name: "acc-102",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.NotContains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
require.Empty(t, repo.modelRateLimitCalls)
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
},
|
||||
},
|
||||
errors: []error{nil},
|
||||
}
|
||||
// 模型已限流 + overages 启用 + 无 AICredits key → 应直接注入积分
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Name: "acc-103",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
// 模型限流 + overages 启用 + AICredits key 生效 → 不应注入积分,应切号
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Name: "acc-104",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
_, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// 模型限流 + 积分耗尽 → 应触发切号错误
|
||||
require.Error(t, err)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_CreditErrorMarksExhausted(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`)),
|
||||
},
|
||||
},
|
||||
errors: []error{nil},
|
||||
}
|
||||
// 模型限流 + overages 启用 + 积分可用 → 注入积分但上游返回积分不足
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Name: "acc-105",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
// 验证 AICredits key 已通过 SetModelRateLimit 写入数据库
|
||||
require.Len(t, repo.modelRateLimitCalls, 1, "应通过 SetModelRateLimit 写入 AICredits key")
|
||||
require.Equal(t, creditsExhaustedKey, repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
func TestShouldMarkCreditsExhausted(t *testing.T) {
|
||||
t.Run("reqErr 不为 nil 时不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), io.ErrUnexpectedEOF))
|
||||
})
|
||||
|
||||
t.Run("resp 为 nil 时不标记", func(t *testing.T) {
|
||||
require.False(t, shouldMarkCreditsExhausted(nil, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("5xx 响应不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusInternalServerError}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("408 RequestTimeout 不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusRequestTimeout}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("URL 级限流不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||
body := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
|
||||
t.Run("结构化限流不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
|
||||
t.Run("含 credits 关键词时标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
for _, keyword := range []string{
|
||||
"Insufficient GOOGLE_ONE_AI credits",
|
||||
"insufficient credit balance",
|
||||
"not enough credits for this request",
|
||||
"Credits exhausted",
|
||||
"minimumCreditAmountForUsage requirement not met",
|
||||
} {
|
||||
body := []byte(`{"error":{"message":"` + keyword + `"}}`)
|
||||
require.True(t, shouldMarkCreditsExhausted(resp, body, nil), "should mark for keyword: %s", keyword)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("无 credits 关键词时不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
body := []byte(`{"error":{"message":"permission denied"}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestInjectEnabledCreditTypes(t *testing.T) {
|
||||
t.Run("正常 JSON 注入成功", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","request":{}}`)
|
||||
result := injectEnabledCreditTypes(body)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, string(result), `"enabledCreditTypes"`)
|
||||
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
|
||||
})
|
||||
|
||||
t.Run("非法 JSON 返回 nil", func(t *testing.T) {
|
||||
require.Nil(t, injectEnabledCreditTypes([]byte(`not json`)))
|
||||
})
|
||||
|
||||
t.Run("空 body 返回 nil", func(t *testing.T) {
|
||||
require.Nil(t, injectEnabledCreditTypes([]byte{}))
|
||||
})
|
||||
|
||||
t.Run("已有 enabledCreditTypes 会被覆盖", func(t *testing.T) {
|
||||
body := []byte(`{"enabledCreditTypes":["OLD"],"model":"test"}`)
|
||||
result := injectEnabledCreditTypes(body)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
|
||||
require.NotContains(t, string(result), `OLD`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearCreditsExhausted(t *testing.T) {
|
||||
t.Run("account 为 nil 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), nil)
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("Extra 为 nil 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{ID: 1})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("无 modelRateLimitsKey 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{"some_key": "value"},
|
||||
})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("无 AICredits key 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("有 AICredits key 时删除并调用 UpdateExtra", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc.clearCreditsExhausted(context.Background(), account)
|
||||
require.Len(t, repo.extraUpdateCalls, 1)
|
||||
// AICredits key 应被删除
|
||||
rawLimits := account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
_, exists := rawLimits[creditsExhaustedKey]
|
||||
require.False(t, exists, "AICredits key 应被删除")
|
||||
// 普通模型限流应保留
|
||||
_, exists = rawLimits["claude-sonnet-4-5"]
|
||||
require.True(t, exists, "普通模型限流应保留")
|
||||
})
|
||||
}
|
||||
@@ -188,9 +188,29 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
return &smartRetryResult{action: smartRetryActionContinueURL}
|
||||
}
|
||||
|
||||
category := antigravity429Unknown
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
category = classifyAntigravity429(respBody)
|
||||
}
|
||||
|
||||
// 判断是否触发智能重试
|
||||
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
||||
|
||||
// AI Credits 超量请求:
|
||||
// 仅在上游明确返回免费配额耗尽时才允许切换到 credits。
|
||||
if resp.StatusCode == http.StatusTooManyRequests &&
|
||||
category == antigravity429QuotaExhausted &&
|
||||
p.account.IsOveragesEnabled() &&
|
||||
!p.account.isCreditsExhausted() {
|
||||
result := s.attemptCreditsOveragesRetry(p, baseURL, modelName, waitDuration, resp.StatusCode, respBody)
|
||||
if result.handled && result.resp != nil {
|
||||
return &smartRetryResult{
|
||||
action: smartRetryActionBreakWithResp,
|
||||
resp: result.resp,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
||||
if shouldRateLimitModel {
|
||||
// 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试
|
||||
@@ -532,14 +552,31 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
|
||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||
// 预检查:模型限流 + overages 启用 + 积分未耗尽 → 直接注入 AI Credits
|
||||
overagesInjected := false
|
||||
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
|
||||
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
|
||||
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
|
||||
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
|
||||
p.body = creditsBody
|
||||
overagesInjected = true
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
|
||||
p.prefix, p.requestedModel, p.account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// 预检查:如果账号已限流,直接返回切换信号
|
||||
if p.requestedModel != "" {
|
||||
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
|
||||
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
|
||||
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
|
||||
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
|
||||
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
|
||||
if isSingleAccountRetry(p.ctx) {
|
||||
// 已注入积分的请求不再受普通模型限流预检查阻断。
|
||||
if overagesInjected {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: credits_injected_ignore_rate_limit remaining=%v model=%s account=%d",
|
||||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||
} else if isSingleAccountRetry(p.ctx) {
|
||||
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
|
||||
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
|
||||
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
|
||||
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
|
||||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||
} else {
|
||||
@@ -631,6 +668,15 @@ urlFallbackLoop:
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if overagesInjected && shouldMarkCreditsExhausted(resp, respBody, nil) {
|
||||
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, "", p.requestedModel)
|
||||
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// ★ 统一入口:自定义错误码 + 临时不可调度
|
||||
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||||
if policyErr != nil {
|
||||
|
||||
@@ -78,11 +78,11 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
|
||||
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||
// 调用 LoadCodeAssist 获取订阅等级和 AI Credits 余额(非关键路径,失败不影响主流程)
|
||||
tierRaw, tierNormalized, loadResp := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
|
||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized, loadResp)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
@@ -90,20 +90,21 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
|
||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
|
||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串。
|
||||
// 同时返回 LoadCodeAssistResponse,以便提取 AI Credits 余额。
|
||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string, loadResp *antigravity.LoadCodeAssistResponse) {
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||
return "", ""
|
||||
return "", "", nil
|
||||
}
|
||||
if loadResp == nil {
|
||||
return "", ""
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||
normalized = normalizeTier(raw)
|
||||
return raw, normalized
|
||||
return raw, normalized, loadResp
|
||||
}
|
||||
|
||||
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||
@@ -124,8 +125,8 @@ func normalizeTier(raw string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo。
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string, loadResp *antigravity.LoadCodeAssistResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
@@ -190,6 +191,16 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
||||
}
|
||||
}
|
||||
|
||||
if loadResp != nil {
|
||||
for _, credit := range loadResp.GetAvailableCredits() {
|
||||
info.AICredits = append(info.AICredits, AICredit{
|
||||
CreditType: credit.CreditType,
|
||||
Amount: credit.GetAmount(),
|
||||
MinimumBalance: credit.GetMinimumAmount(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", nil)
|
||||
|
||||
// 基本字段
|
||||
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||
@@ -141,7 +141,7 @@ func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Len(t, info.ModelForwardingRules, 2)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||
@@ -159,7 +159,7 @@ func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||
}
|
||||
@@ -171,7 +171,7 @@ func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.NotNil(t, info.AntigravityQuota)
|
||||
@@ -193,7 +193,7 @@ func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||
@@ -222,7 +222,7 @@ func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||
@@ -251,7 +251,7 @@ func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||
@@ -277,7 +277,7 @@ func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||
@@ -298,7 +298,7 @@ func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
@@ -317,7 +317,7 @@ func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||
@@ -338,7 +338,7 @@ func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
@@ -358,13 +358,38 @@ func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 0, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_AICredits(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
loadResp := &antigravity.LoadCodeAssistResponse{
|
||||
PaidTier: &antigravity.PaidTierInfo{
|
||||
ID: "g1-pro-tier",
|
||||
AvailableCredits: []antigravity.AvailableCredit{
|
||||
{
|
||||
CreditType: "GOOGLE_ONE_AI",
|
||||
CreditAmount: "25",
|
||||
MinimumCreditAmountForUsage: "5",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", loadResp)
|
||||
|
||||
require.Len(t, info.AICredits, 1)
|
||||
require.Equal(t, "GOOGLE_ONE_AI", info.AICredits[0].CreditType)
|
||||
require.Equal(t, 25.0, info.AICredits[0].Amount)
|
||||
require.Equal(t, 5.0, info.AICredits[0].MinimumBalance)
|
||||
}
|
||||
|
||||
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||
|
||||
@@ -32,6 +32,10 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
|
||||
return false
|
||||
}
|
||||
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
||||
// Antigravity + overages 启用 + 积分未耗尽 → 放行(有积分可用)
|
||||
if a.Platform == PlatformAntigravity && a.IsOveragesEnabled() && !a.isCreditsExhausted() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -76,10 +76,16 @@ type modelRateLimitCall struct {
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type extraUpdateCall struct {
|
||||
accountID int64
|
||||
updates map[string]any
|
||||
}
|
||||
|
||||
type stubAntigravityAccountRepo struct {
|
||||
AccountRepository
|
||||
rateCalls []rateLimitCall
|
||||
modelRateLimitCalls []modelRateLimitCall
|
||||
extraUpdateCalls []extraUpdateCall
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
@@ -92,6 +98,11 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
s.extraUpdateCalls = append(s.extraUpdateCalls, extraUpdateCall{accountID: id, updates: updates})
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
|
||||
t.Setenv(antigravityForwardBaseURLEnv, "")
|
||||
|
||||
|
||||
@@ -32,15 +32,23 @@ func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID
|
||||
|
||||
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
|
||||
type mockSmartRetryUpstream struct {
|
||||
responses []*http.Response
|
||||
errors []error
|
||||
callIdx int
|
||||
calls []string
|
||||
responses []*http.Response
|
||||
errors []error
|
||||
callIdx int
|
||||
calls []string
|
||||
requestBodies [][]byte
|
||||
}
|
||||
|
||||
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
idx := m.callIdx
|
||||
m.calls = append(m.calls, req.URL.String())
|
||||
if req != nil && req.Body != nil {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
m.requestBodies = append(m.requestBodies, body)
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
} else {
|
||||
m.requestBodies = append(m.requestBodies, nil)
|
||||
}
|
||||
m.callIdx++
|
||||
if idx < len(m.responses) {
|
||||
return m.responses[idx], m.errors[idx]
|
||||
|
||||
@@ -3,7 +3,6 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -17,15 +16,18 @@ const (
|
||||
antigravityBackfillCooldown = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
// AntigravityTokenCache token cache interface.
|
||||
type AntigravityTokenCache = GeminiTokenCache
|
||||
|
||||
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
|
||||
// AntigravityTokenProvider manages access_token for antigravity accounts.
|
||||
type AntigravityTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache AntigravityTokenCache
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
|
||||
backfillCooldown sync.Map // key: accountID -> last attempt time
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewAntigravityTokenProvider(
|
||||
@@ -37,10 +39,22 @@ func NewAntigravityTokenProvider(
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
refreshPolicy: AntigravityProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||
func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
// SetRefreshPolicy injects caller-side refresh policy.
|
||||
func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
// GetAccessToken returns a valid access_token.
|
||||
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
@@ -48,7 +62,8 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return "", errors.New("not an antigravity account")
|
||||
}
|
||||
// upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程
|
||||
|
||||
// upstream accounts use static api_key and never refresh oauth token.
|
||||
if account.Type == AccountTypeUpstream {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
@@ -62,46 +77,38 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
|
||||
cacheKey := AntigravityTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
// default policy: continue with existing token.
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
// Backward-compatible test path when refreshAPI is not injected.
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
|
||||
if p.antigravityOAuthService == nil {
|
||||
return "", errors.New("antigravity oauth service not configured")
|
||||
}
|
||||
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
p.mergeCredentials(account, tokenInfo)
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,32 +117,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
|
||||
// "Invalid project resource name projects/"。
|
||||
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
|
||||
// Backfill project_id online when missing, with cooldown to avoid hammering.
|
||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||
if p.shouldAttemptBackfill(account.ID) {
|
||||
p.markBackfillAttempted(account.ID)
|
||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||
account.Credentials["project_id"] = projectID
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
|
||||
slog.Warn("antigravity_project_id_backfill_persist_failed",
|
||||
"account_id", account.ID,
|
||||
"error", updateErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
@@ -156,18 +162,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
|
||||
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
}
|
||||
|
||||
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
|
||||
// shouldAttemptBackfill checks backfill cooldown.
|
||||
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
||||
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
||||
if lastAttempt, ok := v.(time.Time); ok {
|
||||
|
||||
@@ -25,6 +25,11 @@ func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthServi
|
||||
}
|
||||
}
|
||||
|
||||
// CacheKey 返回用于分布式锁的缓存键
|
||||
func (r *AntigravityTokenRefresher) CacheKey(account *Account) string {
|
||||
return AntigravityTokenCacheKey(account)
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否可以刷新此账户
|
||||
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
||||
@@ -58,11 +63,7 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
||||
|
||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
|
||||
@@ -22,8 +22,9 @@ const (
|
||||
)
|
||||
|
||||
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
||||
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
|
||||
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
||||
return windowStart != nil && time.Since(*windowStart) >= duration
|
||||
return windowStart == nil || time.Since(*windowStart) >= duration
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
|
||||
@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil window start",
|
||||
name: "nil window start (treated as expired)",
|
||||
start: nil,
|
||||
duration: RateLimitWindow5h,
|
||||
want: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active window (started 1h ago, 5h window)",
|
||||
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return raw usage",
|
||||
name: "nil window starts return 0 (stale usage reset)",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 5.0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed: 5h expired, 1d active, 7d nil",
|
||||
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "zero usage with active windows",
|
||||
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return raw usage",
|
||||
name: "nil window starts return 0 (stale usage reset)",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 3.0,
|
||||
want1d: 8.0,
|
||||
want7d: 40.0,
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -15,14 +14,17 @@ const (
|
||||
claudeLockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
// ClaudeTokenCache token cache interface.
|
||||
type ClaudeTokenCache = GeminiTokenCache
|
||||
|
||||
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
|
||||
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
|
||||
type ClaudeTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
oauthService *OAuthService
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
oauthService *OAuthService
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewClaudeTokenProvider(
|
||||
@@ -31,13 +33,25 @@ func NewClaudeTokenProvider(
|
||||
oauthService *OAuthService,
|
||||
) *ClaudeTokenProvider {
|
||||
return &ClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
oauthService: oauthService,
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
oauthService: oauthService,
|
||||
refreshPolicy: ClaudeProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||
func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
// SetRefreshPolicy injects caller-side refresh policy.
|
||||
func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
// GetAccessToken returns a valid access_token.
|
||||
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
@@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
||||
@@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
// Backward-compatible test path when refreshAPI is not injected.
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
} else if lockErr != nil {
|
||||
slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
} else {
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
@@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
if p.refreshPolicy.FailureTTL > 0 {
|
||||
ttl = p.refreshPolicy.FailureTTL
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
|
||||
@@ -80,6 +80,7 @@ const (
|
||||
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组)
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
|
||||
// 邮件服务设置
|
||||
|
||||
@@ -440,7 +440,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded 测试所有账户被排除
|
||||
@@ -1073,7 +1073,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts(t *testing.
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
}
|
||||
|
||||
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
@@ -1734,7 +1734,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
})
|
||||
|
||||
t.Run("混合调度-不支持模型返回错误", func(t *testing.T) {
|
||||
@@ -2290,7 +2290,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
})
|
||||
|
||||
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
|
||||
|
||||
@@ -369,3 +369,54 @@ func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_ReasoningEffortPersisted(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
effort := "max"
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "effort_test",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
Model: "claude-opus-4-6",
|
||||
Duration: time.Second,
|
||||
ReasoningEffort: &effort,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1},
|
||||
User: &User{ID: 1},
|
||||
Account: &Account{ID: 1},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
|
||||
require.Equal(t, "max", *usageRepo.lastLog.ReasoningEffort)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_ReasoningEffortNil(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "no_effort_test",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1},
|
||||
User: &User{ID: 1},
|
||||
Account: &Account{ID: 1},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Nil(t, usageRepo.lastLog.ReasoningEffort)
|
||||
}
|
||||
|
||||
@@ -60,6 +60,7 @@ type ParsedRequest struct {
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
OutputEffort string // output_config.effort(Claude API 的推理强度控制)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||
|
||||
@@ -116,6 +117,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
||||
parsed.ThinkingEnabled = true
|
||||
}
|
||||
|
||||
// output_config.effort: Claude API 的推理强度控制参数
|
||||
parsed.OutputEffort = strings.TrimSpace(gjson.Get(jsonStr, "output_config.effort").String())
|
||||
|
||||
// max_tokens: 仅接受整数值
|
||||
maxTokensResult := gjson.Get(jsonStr, "max_tokens")
|
||||
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
|
||||
@@ -747,6 +751,21 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
return newBody
|
||||
}
|
||||
|
||||
// NormalizeClaudeOutputEffort normalizes Claude's output_config.effort value.
|
||||
// Returns nil for empty or unrecognized values.
|
||||
func NormalizeClaudeOutputEffort(raw string) *string {
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
switch value {
|
||||
case "low", "medium", "high", "max":
|
||||
return &value
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Thinking Budget Rectifier
|
||||
// =========================
|
||||
|
||||
@@ -972,6 +972,76 @@ func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_OutputEffort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantEffort string
|
||||
}{
|
||||
{
|
||||
name: "output_config.effort present",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{"effort":"medium"},"messages":[]}`,
|
||||
wantEffort: "medium",
|
||||
},
|
||||
{
|
||||
name: "output_config.effort max",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
|
||||
wantEffort: "max",
|
||||
},
|
||||
{
|
||||
name: "output_config without effort",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
|
||||
wantEffort: "",
|
||||
},
|
||||
{
|
||||
name: "no output_config",
|
||||
body: `{"model":"claude-opus-4-6","messages":[]}`,
|
||||
wantEffort: "",
|
||||
},
|
||||
{
|
||||
name: "effort with whitespace trimmed",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{"effort":" high "},"messages":[]}`,
|
||||
wantEffort: "high",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantEffort, parsed.OutputEffort)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeOutputEffort(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want *string
|
||||
}{
|
||||
{"low", strPtr("low")},
|
||||
{"medium", strPtr("medium")},
|
||||
{"high", strPtr("high")},
|
||||
{"max", strPtr("max")},
|
||||
{"LOW", strPtr("low")},
|
||||
{"Max", strPtr("max")},
|
||||
{" medium ", strPtr("medium")},
|
||||
{"", nil},
|
||||
{"unknown", nil},
|
||||
{"xhigh", nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := NormalizeClaudeOutputEffort(tt.input)
|
||||
if tt.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, *tt.want, *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
|
||||
data := buildLargeJSON()
|
||||
b.SetBytes(int64(len(data)))
|
||||
|
||||
@@ -346,6 +346,9 @@ var systemBlockFilterPrefixes = []string{
|
||||
"x-anthropic-billing-header",
|
||||
}
|
||||
|
||||
// ErrNoAvailableAccounts 表示没有可用的账号
|
||||
var ErrNoAvailableAccounts = errors.New("no available accounts")
|
||||
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||
|
||||
@@ -492,6 +495,7 @@ type ForwardResult struct {
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
ReasoningEffort *string
|
||||
|
||||
// 图片生成计费字段(图片生成模型使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
@@ -1204,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
@@ -1552,7 +1556,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||
@@ -1641,7 +1645,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
@@ -2851,9 +2855,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if selected == nil {
|
||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
|
||||
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
@@ -3089,9 +3093,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if selected == nil {
|
||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true)
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
|
||||
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
@@ -7126,6 +7130,8 @@ type RecordUsageInput struct {
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
InboundEndpoint string // 入站端点(客户端请求路径)
|
||||
UpstreamEndpoint string // 上游端点(标准化后的上游路径)
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
@@ -7523,6 +7529,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
@@ -7603,6 +7612,8 @@ type RecordUsageLongContextInput struct {
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
InboundEndpoint string // 入站端点(客户端请求路径)
|
||||
UpstreamEndpoint string // 上游端点(标准化后的上游路径)
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
@@ -7699,6 +7710,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
|
||||
@@ -3235,7 +3235,7 @@ func cleanToolSchema(schema any) any {
|
||||
for key, value := range v {
|
||||
// 跳过不支持的字段
|
||||
if key == "$schema" || key == "$id" || key == "$ref" ||
|
||||
key == "additionalProperties" || key == "minLength" ||
|
||||
key == "additionalProperties" || key == "patternProperties" || key == "minLength" ||
|
||||
key == "maxLength" || key == "minItems" || key == "maxItems" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -15,10 +15,14 @@ const (
|
||||
geminiTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
|
||||
type GeminiTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache GeminiTokenCache
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewGeminiTokenProvider(
|
||||
@@ -30,9 +34,21 @@ func NewGeminiTokenProvider(
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
refreshPolicy: GeminiProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||
func (p *GeminiTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
// SetRefreshPolicy injects caller-side refresh policy.
|
||||
func (p *GeminiTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
@@ -53,39 +69,31 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Re-check after lock (another worker may have refreshed).
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, geminiTokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
slog.Debug("gemini_token_lock_held_use_old", "account_id", account.ID)
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew {
|
||||
if p.geminiOAuthService == nil {
|
||||
return "", errors.New("gemini oauth service not configured")
|
||||
}
|
||||
tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
// Backward-compatible test path when refreshAPI is not injected.
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
} else if lockErr != nil {
|
||||
slog.Warn("gemini_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,15 +103,14 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
// project_id is optional now:
|
||||
// - If present: will use Code Assist API (requires project_id)
|
||||
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
|
||||
// Auto-detect project_id only if explicitly enabled via a credential flag
|
||||
// - If present: use Code Assist API (requires project_id)
|
||||
// - If absent: use AI Studio API with OAuth token.
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
|
||||
|
||||
if projectID == "" && autoDetectProjectID {
|
||||
if p.geminiOAuthService == nil {
|
||||
return accessToken, nil // Fallback to AI Studio API mode
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
@@ -132,17 +139,15 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
}
|
||||
|
||||
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
|
||||
@@ -13,6 +13,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke
|
||||
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
// CacheKey 返回用于分布式锁的缓存键
|
||||
func (r *GeminiTokenRefresher) CacheKey(account *Account) string {
|
||||
return GeminiTokenCacheKey(account)
|
||||
}
|
||||
|
||||
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
|
||||
}
|
||||
@@ -35,11 +40,7 @@ func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
||||
}
|
||||
|
||||
newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
159
backend/internal/service/oauth_refresh_api.go
Normal file
159
backend/internal/service/oauth_refresh_api.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
|
||||
// TokenRefresher 接口的超集:增加了 CacheKey 方法用于分布式锁
|
||||
type OAuthRefreshExecutor interface {
|
||||
TokenRefresher
|
||||
|
||||
// CacheKey 返回用于分布式锁的缓存键(与 TokenProvider 使用的一致)
|
||||
CacheKey(account *Account) string
|
||||
}
|
||||
|
||||
const refreshLockTTL = 30 * time.Second
|
||||
|
||||
// OAuthRefreshResult 统一刷新结果
|
||||
type OAuthRefreshResult struct {
|
||||
Refreshed bool // 实际执行了刷新
|
||||
NewCredentials map[string]any // 刷新后的 credentials(nil 表示未刷新)
|
||||
Account *Account // 从 DB 重新读取的最新 account
|
||||
LockHeld bool // 锁被其他 worker 持有(未执行刷新)
|
||||
}
|
||||
|
||||
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
||||
// 封装分布式锁、DB 重读、已刷新检查等通用逻辑
|
||||
type OAuthRefreshAPI struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache GeminiTokenCache // 可选,nil = 无锁
|
||||
}
|
||||
|
||||
// NewOAuthRefreshAPI 创建统一刷新 API
|
||||
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
|
||||
return &OAuthRefreshAPI{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
||||
//
|
||||
// 流程:
|
||||
// 1. 获取分布式锁
|
||||
// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||||
// 3. 二次检查是否仍需刷新
|
||||
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||||
// 5. 设置 _token_version + 更新 DB
|
||||
// 6. 释放锁
|
||||
func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
executor OAuthRefreshExecutor,
|
||||
refreshWindow time.Duration,
|
||||
) (*OAuthRefreshResult, error) {
|
||||
cacheKey := executor.CacheKey(account)
|
||||
|
||||
// 1. 获取分布式锁
|
||||
lockAcquired := false
|
||||
if api.tokenCache != nil {
|
||||
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
|
||||
if lockErr != nil {
|
||||
// Redis 错误,降级为无锁刷新
|
||||
slog.Warn("oauth_refresh_lock_failed_degraded",
|
||||
"account_id", account.ID,
|
||||
"cache_key", cacheKey,
|
||||
"error", lockErr,
|
||||
)
|
||||
} else if !acquired {
|
||||
// 锁被其他 worker 持有
|
||||
return &OAuthRefreshResult{LockHeld: true}, nil
|
||||
} else {
|
||||
lockAcquired = true
|
||||
defer func() { _ = api.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 从 DB 重读最新 account(锁保护下,确保使用最新的 refresh_token)
|
||||
freshAccount, err := api.accountRepo.GetByID(ctx, account.ID)
|
||||
if err != nil {
|
||||
slog.Warn("oauth_refresh_db_reread_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
// 降级使用传入的 account
|
||||
freshAccount = account
|
||||
} else if freshAccount == nil {
|
||||
freshAccount = account
|
||||
}
|
||||
|
||||
// 3. 二次检查是否仍需刷新(另一条路径可能已刷新)
|
||||
if !executor.NeedsRefresh(freshAccount, refreshWindow) {
|
||||
return &OAuthRefreshResult{
|
||||
Account: freshAccount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 4. 执行平台特定刷新逻辑
|
||||
newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
|
||||
if refreshErr != nil {
|
||||
return nil, refreshErr
|
||||
}
|
||||
|
||||
// 5. 设置版本号 + 更新 DB
|
||||
if newCredentials != nil {
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
freshAccount.Credentials = newCredentials
|
||||
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
|
||||
slog.Error("oauth_refresh_update_failed",
|
||||
"account_id", freshAccount.ID,
|
||||
"error", updateErr,
|
||||
)
|
||||
return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr)
|
||||
}
|
||||
}
|
||||
|
||||
_ = lockAcquired // suppress unused warning when tokenCache is nil
|
||||
|
||||
return &OAuthRefreshResult{
|
||||
Refreshed: true,
|
||||
NewCredentials: newCredentials,
|
||||
Account: freshAccount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中
|
||||
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
|
||||
if newCreds == nil {
|
||||
newCreds = make(map[string]any)
|
||||
}
|
||||
for k, v := range oldCreds {
|
||||
if _, exists := newCreds[k]; !exists {
|
||||
newCreds[k] = v
|
||||
}
|
||||
}
|
||||
return newCreds
|
||||
}
|
||||
|
||||
// BuildClaudeAccountCredentials 为 Claude 平台构建 OAuth credentials map
|
||||
// 消除 Claude 平台没有 BuildAccountCredentials 方法的问题
|
||||
func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"token_type": tokenInfo.TokenType,
|
||||
"expires_in": strconv.FormatInt(tokenInfo.ExpiresIn, 10),
|
||||
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
|
||||
}
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
creds["scope"] = tokenInfo.Scope
|
||||
}
|
||||
return creds
|
||||
}
|
||||
395
backend/internal/service/oauth_refresh_api_test.go
Normal file
395
backend/internal/service/oauth_refresh_api_test.go
Normal file
@@ -0,0 +1,395 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------- mock helpers ----------
|
||||
|
||||
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
|
||||
type refreshAPIAccountRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
account *Account // returned by GetByID
|
||||
getByIDErr error
|
||||
updateErr error
|
||||
updateCalls int
|
||||
}
|
||||
|
||||
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||
if r.getByIDErr != nil {
|
||||
return nil, r.getByIDErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
|
||||
r.updateCalls++
|
||||
return r.updateErr
|
||||
}
|
||||
|
||||
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
|
||||
type refreshAPIExecutorStub struct {
|
||||
needsRefresh bool
|
||||
credentials map[string]any
|
||||
err error
|
||||
refreshCalls int
|
||||
}
|
||||
|
||||
func (e *refreshAPIExecutorStub) CanRefresh(_ *Account) bool { return true }
|
||||
|
||||
func (e *refreshAPIExecutorStub) NeedsRefresh(_ *Account, _ time.Duration) bool {
|
||||
return e.needsRefresh
|
||||
}
|
||||
|
||||
func (e *refreshAPIExecutorStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
|
||||
e.refreshCalls++
|
||||
if e.err != nil {
|
||||
return nil, e.err
|
||||
}
|
||||
return e.credentials, nil
|
||||
}
|
||||
|
||||
func (e *refreshAPIExecutorStub) CacheKey(account *Account) string {
|
||||
return "test:api:" + account.Platform
|
||||
}
|
||||
|
||||
// refreshAPICacheStub implements GeminiTokenCache for OAuthRefreshAPI tests.
|
||||
type refreshAPICacheStub struct {
|
||||
lockResult bool
|
||||
lockErr error
|
||||
releaseCalls int
|
||||
}
|
||||
|
||||
func (c *refreshAPICacheStub) GetAccessToken(context.Context, string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (c *refreshAPICacheStub) SetAccessToken(context.Context, string, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *refreshAPICacheStub) DeleteAccessToken(context.Context, string) error { return nil }
|
||||
|
||||
func (c *refreshAPICacheStub) AcquireRefreshLock(context.Context, string, time.Duration) (bool, error) {
|
||||
return c.lockResult, c.lockErr
|
||||
}
|
||||
|
||||
func (c *refreshAPICacheStub) ReleaseRefreshLock(context.Context, string) error {
|
||||
c.releaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========== RefreshIfNeeded tests ==========
|
||||
|
||||
func TestRefreshIfNeeded_Success(t *testing.T) {
|
||||
account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "new-token"},
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Refreshed)
|
||||
require.NotNil(t, result.NewCredentials)
|
||||
require.Equal(t, "new-token", result.NewCredentials["access_token"])
|
||||
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
|
||||
require.Equal(t, 1, repo.updateCalls) // DB updated
|
||||
require.Equal(t, 1, cache.releaseCalls) // lock released
|
||||
require.Equal(t, 1, executor.refreshCalls)
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_LockHeld(t *testing.T) {
|
||||
account := &Account{ID: 2, Platform: PlatformAnthropic}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
cache := &refreshAPICacheStub{lockResult: false} // lock not acquired
|
||||
executor := &refreshAPIExecutorStub{needsRefresh: true}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.LockHeld)
|
||||
require.False(t, result.Refreshed)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, executor.refreshCalls)
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_LockErrorDegrades(t *testing.T) {
|
||||
account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
cache := &refreshAPICacheStub{lockErr: errors.New("redis down")} // lock error
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "degraded-token"},
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Refreshed) // still refreshed (degraded mode)
|
||||
require.Equal(t, 1, repo.updateCalls) // DB updated
|
||||
require.Equal(t, 0, cache.releaseCalls) // no lock to release
|
||||
require.Equal(t, 1, executor.refreshCalls)
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_NoCacheNoLock(t *testing.T) {
|
||||
account := &Account{ID: 4, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "no-cache-token"},
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, nil) // no cache = no lock
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Refreshed)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_AlreadyRefreshed(t *testing.T) {
|
||||
account := &Account{ID: 5, Platform: PlatformAnthropic}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{needsRefresh: false} // already refreshed
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Refreshed)
|
||||
require.False(t, result.LockHeld)
|
||||
require.NotNil(t, result.Account) // returns fresh account
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, executor.refreshCalls)
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_RefreshError(t *testing.T) {
|
||||
account := &Account{ID: 6, Platform: PlatformAnthropic}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
err: errors.New("invalid_grant: token revoked"),
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "invalid_grant")
|
||||
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
|
||||
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_DBUpdateError(t *testing.T) {
|
||||
account := &Account{ID: 7, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||
repo := &refreshAPIAccountRepo{
|
||||
account: account,
|
||||
updateErr: errors.New("db connection lost"),
|
||||
}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "token"},
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "DB update failed")
|
||||
require.Equal(t, 1, repo.updateCalls) // attempted
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_DBRereadFails(t *testing.T) {
|
||||
account := &Account{ID: 8, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
|
||||
repo := &refreshAPIAccountRepo{
|
||||
account: nil, // GetByID returns nil
|
||||
getByIDErr: errors.New("db timeout"),
|
||||
}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "fallback-token"},
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Refreshed)
|
||||
require.Equal(t, 1, executor.refreshCalls) // still refreshes using passed-in account
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_NilCredentials(t *testing.T) {
|
||||
account := &Account{ID: 9, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: nil, // Refresh returns nil credentials
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Refreshed)
|
||||
require.Nil(t, result.NewCredentials)
|
||||
require.Equal(t, 0, repo.updateCalls) // no DB update when credentials are nil
|
||||
}
|
||||
|
||||
// ========== MergeCredentials tests ==========
|
||||
|
||||
func TestMergeCredentials_Basic(t *testing.T) {
|
||||
old := map[string]any{"a": "1", "b": "2", "c": "3"}
|
||||
new := map[string]any{"a": "new", "d": "4"}
|
||||
|
||||
result := MergeCredentials(old, new)
|
||||
|
||||
require.Equal(t, "new", result["a"]) // new value preserved
|
||||
require.Equal(t, "2", result["b"]) // old value kept
|
||||
require.Equal(t, "3", result["c"]) // old value kept
|
||||
require.Equal(t, "4", result["d"]) // new value preserved
|
||||
}
|
||||
|
||||
func TestMergeCredentials_NilNew(t *testing.T) {
|
||||
old := map[string]any{"a": "1"}
|
||||
|
||||
result := MergeCredentials(old, nil)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "1", result["a"])
|
||||
}
|
||||
|
||||
func TestMergeCredentials_NilOld(t *testing.T) {
|
||||
new := map[string]any{"a": "1"}
|
||||
|
||||
result := MergeCredentials(nil, new)
|
||||
|
||||
require.Equal(t, "1", result["a"])
|
||||
}
|
||||
|
||||
func TestMergeCredentials_BothNil(t *testing.T) {
|
||||
result := MergeCredentials(nil, nil)
|
||||
require.NotNil(t, result)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestMergeCredentials_NewOverridesOld(t *testing.T) {
|
||||
old := map[string]any{"access_token": "old-token", "refresh_token": "old-refresh"}
|
||||
new := map[string]any{"access_token": "new-token"}
|
||||
|
||||
result := MergeCredentials(old, new)
|
||||
|
||||
require.Equal(t, "new-token", result["access_token"]) // overridden
|
||||
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
|
||||
}
|
||||
|
||||
// ========== BuildClaudeAccountCredentials tests ==========
|
||||
|
||||
func TestBuildClaudeAccountCredentials_Full(t *testing.T) {
|
||||
tokenInfo := &TokenInfo{
|
||||
AccessToken: "at-123",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
ExpiresAt: 1700000000,
|
||||
RefreshToken: "rt-456",
|
||||
Scope: "openid",
|
||||
}
|
||||
|
||||
creds := BuildClaudeAccountCredentials(tokenInfo)
|
||||
|
||||
require.Equal(t, "at-123", creds["access_token"])
|
||||
require.Equal(t, "Bearer", creds["token_type"])
|
||||
require.Equal(t, "3600", creds["expires_in"])
|
||||
require.Equal(t, "1700000000", creds["expires_at"])
|
||||
require.Equal(t, "rt-456", creds["refresh_token"])
|
||||
require.Equal(t, "openid", creds["scope"])
|
||||
}
|
||||
|
||||
func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) {
|
||||
tokenInfo := &TokenInfo{
|
||||
AccessToken: "at-789",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 7200,
|
||||
ExpiresAt: 1700003600,
|
||||
}
|
||||
|
||||
creds := BuildClaudeAccountCredentials(tokenInfo)
|
||||
|
||||
require.Equal(t, "at-789", creds["access_token"])
|
||||
require.Equal(t, "Bearer", creds["token_type"])
|
||||
require.Equal(t, "7200", creds["expires_in"])
|
||||
require.Equal(t, "1700003600", creds["expires_at"])
|
||||
_, hasRefresh := creds["refresh_token"]
|
||||
_, hasScope := creds["scope"]
|
||||
require.False(t, hasRefresh, "refresh_token should not be set when empty")
|
||||
require.False(t, hasScope, "scope should not be set when empty")
|
||||
}
|
||||
|
||||
// ========== BackgroundRefreshPolicy tests ==========
|
||||
|
||||
func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) {
|
||||
p := DefaultBackgroundRefreshPolicy()
|
||||
|
||||
require.ErrorIs(t, p.handleLockHeld(), errRefreshSkipped)
|
||||
require.ErrorIs(t, p.handleAlreadyRefreshed(), errRefreshSkipped)
|
||||
}
|
||||
|
||||
func TestBackgroundRefreshPolicy_SuccessOverride(t *testing.T) {
|
||||
p := BackgroundRefreshPolicy{
|
||||
OnLockHeld: BackgroundSkipAsSuccess,
|
||||
OnAlreadyRefresh: BackgroundSkipAsSuccess,
|
||||
}
|
||||
|
||||
require.NoError(t, p.handleLockHeld())
|
||||
require.NoError(t, p.handleAlreadyRefreshed())
|
||||
}
|
||||
|
||||
// ========== ProviderRefreshPolicy tests ==========
|
||||
|
||||
func TestClaudeProviderRefreshPolicy(t *testing.T) {
|
||||
p := ClaudeProviderRefreshPolicy()
|
||||
require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError)
|
||||
require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld)
|
||||
require.Equal(t, time.Minute, p.FailureTTL)
|
||||
}
|
||||
|
||||
func TestOpenAIProviderRefreshPolicy(t *testing.T) {
|
||||
p := OpenAIProviderRefreshPolicy()
|
||||
require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError)
|
||||
require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld)
|
||||
require.Equal(t, time.Minute, p.FailureTTL)
|
||||
}
|
||||
|
||||
func TestGeminiProviderRefreshPolicy(t *testing.T) {
|
||||
p := GeminiProviderRefreshPolicy()
|
||||
require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError)
|
||||
require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld)
|
||||
require.Equal(t, time.Duration(0), p.FailureTTL)
|
||||
}
|
||||
|
||||
func TestAntigravityProviderRefreshPolicy(t *testing.T) {
|
||||
p := AntigravityProviderRefreshPolicy()
|
||||
require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError)
|
||||
require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld)
|
||||
require.Equal(t, time.Duration(0), p.FailureTTL)
|
||||
}
|
||||
@@ -725,7 +725,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
|
||||
return nil, len(candidates), topK, loadSkew, errors.New("no available accounts")
|
||||
return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
|
||||
@@ -226,6 +226,41 @@ func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_IncludesEndpointMetadata(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
rateRepo := &openAIUserGroupRateRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_endpoint_metadata",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 2,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1002,
|
||||
Group: &Group{RateMultiplier: 1},
|
||||
},
|
||||
User: &User{ID: 2002},
|
||||
Account: &Account{ID: 3002},
|
||||
InboundEndpoint: " /v1/chat/completions ",
|
||||
UpstreamEndpoint: " /v1/responses ",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.NotNil(t, usageRepo.lastLog.InboundEndpoint)
|
||||
require.Equal(t, "/v1/chat/completions", *usageRepo.lastLog.InboundEndpoint)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamEndpoint)
|
||||
require.Equal(t, "/v1/responses", *usageRepo.lastLog.UpstreamEndpoint)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
|
||||
groupID := int64(12)
|
||||
groupRate := 1.6
|
||||
|
||||
@@ -1312,7 +1312,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
isExcluded := func(accountID int64) bool {
|
||||
@@ -1382,7 +1382,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||
@@ -1489,7 +1489,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||
@@ -4028,6 +4028,8 @@ type OpenAIRecordUsageInput struct {
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
InboundEndpoint string
|
||||
UpstreamEndpoint string
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
@@ -4106,6 +4108,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
Model: billingModel,
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
@@ -4125,7 +4129,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
@@ -4668,3 +4671,11 @@ func normalizeOpenAIReasoningEffort(raw string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func optionalTrimmedStringPtr(raw string) *string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
return &trimmed
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ const (
|
||||
openAILockWarnThresholdMs = 250
|
||||
)
|
||||
|
||||
// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。
|
||||
// OpenAITokenRuntimeMetrics is a snapshot of refresh and lock contention metrics.
|
||||
type OpenAITokenRuntimeMetrics struct {
|
||||
RefreshRequests int64
|
||||
RefreshSuccess int64
|
||||
@@ -72,15 +72,18 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
|
||||
m.lastObservedUnixMs.Store(time.Now().UnixMilli())
|
||||
}
|
||||
|
||||
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
// OpenAITokenCache token cache interface.
|
||||
type OpenAITokenCache = GeminiTokenCache
|
||||
|
||||
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
|
||||
// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
|
||||
type OpenAITokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache OpenAITokenCache
|
||||
openAIOAuthService *OpenAIOAuthService
|
||||
metrics *openAITokenRuntimeMetricsStore
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewOpenAITokenProvider(
|
||||
@@ -93,9 +96,21 @@ func NewOpenAITokenProvider(
|
||||
tokenCache: tokenCache,
|
||||
openAIOAuthService: openAIOAuthService,
|
||||
metrics: &openAITokenRuntimeMetricsStore{},
|
||||
refreshPolicy: OpenAIProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||
func (p *OpenAITokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
// SetRefreshPolicy injects caller-side refresh policy.
|
||||
func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
||||
if p == nil {
|
||||
return OpenAITokenRuntimeMetrics{}
|
||||
@@ -110,7 +125,7 @@ func (p *OpenAITokenProvider) ensureMetrics() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
// GetAccessToken returns a valid access_token.
|
||||
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
p.ensureMetrics()
|
||||
if account == nil {
|
||||
@@ -122,7 +137,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
|
||||
@@ -134,114 +149,62 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
p.metrics.refreshRequests.Add(1)
|
||||
p.metrics.touchNow()
|
||||
|
||||
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
|
||||
if account.Platform == PlatformSora {
|
||||
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
|
||||
p.metrics.lockContention.Add(1)
|
||||
p.metrics.touchNow()
|
||||
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||
if waitErr != nil {
|
||||
return "", waitErr
|
||||
}
|
||||
if strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
} else if result.Refreshed {
|
||||
p.metrics.refreshSuccess.Add(1)
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
// Backward-compatible test path when refreshAPI is not injected.
|
||||
p.metrics.refreshRequests.Add(1)
|
||||
p.metrics.touchNow()
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if account.Platform == PlatformSora {
|
||||
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
||||
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||
refreshFailed = true
|
||||
} else if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
p.metrics.refreshSuccess.Add(1)
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
p.metrics.lockAcquireFailure.Add(1)
|
||||
p.metrics.touchNow()
|
||||
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if account.Platform == PlatformSora {
|
||||
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
|
||||
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||
refreshFailed = true
|
||||
} else if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
p.metrics.refreshSuccess.Add(1)
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Warn("openai_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
} else {
|
||||
// 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。
|
||||
p.metrics.lockContention.Add(1)
|
||||
p.metrics.touchNow()
|
||||
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||
@@ -260,22 +223,23 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetOpenAIAccessToken()
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
if p.refreshPolicy.FailureTTL > 0 {
|
||||
ttl = p.refreshPolicy.FailureTTL
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
|
||||
@@ -467,7 +467,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: selErr.Error()}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "no available accounts"}
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: ErrNoAvailableAccounts.Error()}
|
||||
}
|
||||
|
||||
account := selection.Account
|
||||
|
||||
@@ -368,13 +368,14 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
|
||||
Aggregation: OpsAggregationSettings{
|
||||
AggregationEnabled: false,
|
||||
},
|
||||
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
|
||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||
DisplayOpenAITokenStats: false,
|
||||
DisplayAlertEvents: true,
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
|
||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||
IgnoreInsufficientBalanceErrors: false, // 默认不忽略,余额不足可能需要关注
|
||||
DisplayOpenAITokenStats: false,
|
||||
DisplayAlertEvents: true,
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -92,16 +92,17 @@ type OpsAlertRuntimeSettings struct {
|
||||
|
||||
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
|
||||
type OpsAdvancedSettings struct {
|
||||
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
||||
Aggregation OpsAggregationSettings `json:"aggregation"`
|
||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
||||
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
|
||||
DisplayAlertEvents bool `json:"display_alert_events"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
||||
Aggregation OpsAggregationSettings `json:"aggregation"`
|
||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
||||
IgnoreInsufficientBalanceErrors bool `json:"ignore_insufficient_balance_errors"`
|
||||
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
|
||||
DisplayAlertEvents bool `json:"display_alert_events"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
}
|
||||
|
||||
type OpsDataRetentionSettings struct {
|
||||
|
||||
@@ -1174,7 +1174,8 @@ func hasRecoverableRuntimeState(account *Account) bool {
|
||||
if len(account.Extra) == 0 {
|
||||
return false
|
||||
}
|
||||
return hasNonEmptyMapValue(account.Extra, "model_rate_limits") || hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes")
|
||||
return hasNonEmptyMapValue(account.Extra, "model_rate_limits") ||
|
||||
hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes")
|
||||
}
|
||||
|
||||
func hasNonEmptyMapValue(extra map[string]any, key string) bool {
|
||||
|
||||
99
backend/internal/service/refresh_policy.go
Normal file
99
backend/internal/service/refresh_policy.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
// ProviderRefreshErrorAction 定义 provider 在刷新失败时的处理动作。
|
||||
type ProviderRefreshErrorAction int
|
||||
|
||||
const (
|
||||
// ProviderRefreshErrorReturn 失败即返回错误(不降级旧 token)。
|
||||
ProviderRefreshErrorReturn ProviderRefreshErrorAction = iota
|
||||
// ProviderRefreshErrorUseExistingToken 失败后继续使用现有 token。
|
||||
ProviderRefreshErrorUseExistingToken
|
||||
)
|
||||
|
||||
// ProviderLockHeldAction 定义 provider 在刷新锁被占用时的处理动作。
|
||||
type ProviderLockHeldAction int
|
||||
|
||||
const (
|
||||
// ProviderLockHeldUseExistingToken 直接使用现有 token。
|
||||
ProviderLockHeldUseExistingToken ProviderLockHeldAction = iota
|
||||
// ProviderLockHeldWaitForCache 等待后重试缓存读取。
|
||||
ProviderLockHeldWaitForCache
|
||||
)
|
||||
|
||||
// ProviderRefreshPolicy 描述 provider 的平台差异策略。
|
||||
type ProviderRefreshPolicy struct {
|
||||
OnRefreshError ProviderRefreshErrorAction
|
||||
OnLockHeld ProviderLockHeldAction
|
||||
FailureTTL time.Duration
|
||||
}
|
||||
|
||||
func ClaudeProviderRefreshPolicy() ProviderRefreshPolicy {
|
||||
return ProviderRefreshPolicy{
|
||||
OnRefreshError: ProviderRefreshErrorUseExistingToken,
|
||||
OnLockHeld: ProviderLockHeldWaitForCache,
|
||||
FailureTTL: time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func OpenAIProviderRefreshPolicy() ProviderRefreshPolicy {
|
||||
return ProviderRefreshPolicy{
|
||||
OnRefreshError: ProviderRefreshErrorUseExistingToken,
|
||||
OnLockHeld: ProviderLockHeldWaitForCache,
|
||||
FailureTTL: time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func GeminiProviderRefreshPolicy() ProviderRefreshPolicy {
|
||||
return ProviderRefreshPolicy{
|
||||
OnRefreshError: ProviderRefreshErrorReturn,
|
||||
OnLockHeld: ProviderLockHeldUseExistingToken,
|
||||
FailureTTL: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func AntigravityProviderRefreshPolicy() ProviderRefreshPolicy {
|
||||
return ProviderRefreshPolicy{
|
||||
OnRefreshError: ProviderRefreshErrorReturn,
|
||||
OnLockHeld: ProviderLockHeldUseExistingToken,
|
||||
FailureTTL: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// BackgroundSkipAction 定义后台刷新服务在“未实际刷新”场景的计数方式。
|
||||
type BackgroundSkipAction int
|
||||
|
||||
const (
|
||||
// BackgroundSkipAsSkipped 计入 skipped(保持当前默认行为)。
|
||||
BackgroundSkipAsSkipped BackgroundSkipAction = iota
|
||||
// BackgroundSkipAsSuccess 计入 success(仅用于兼容旧统计口径时可选)。
|
||||
BackgroundSkipAsSuccess
|
||||
)
|
||||
|
||||
// BackgroundRefreshPolicy 描述后台刷新服务的调用侧策略。
|
||||
type BackgroundRefreshPolicy struct {
|
||||
OnLockHeld BackgroundSkipAction
|
||||
OnAlreadyRefresh BackgroundSkipAction
|
||||
}
|
||||
|
||||
func DefaultBackgroundRefreshPolicy() BackgroundRefreshPolicy {
|
||||
return BackgroundRefreshPolicy{
|
||||
OnLockHeld: BackgroundSkipAsSkipped,
|
||||
OnAlreadyRefresh: BackgroundSkipAsSkipped,
|
||||
}
|
||||
}
|
||||
|
||||
func (p BackgroundRefreshPolicy) handleLockHeld() error {
|
||||
if p.OnLockHeld == BackgroundSkipAsSuccess {
|
||||
return nil
|
||||
}
|
||||
return errRefreshSkipped
|
||||
}
|
||||
|
||||
func (p BackgroundRefreshPolicy) handleAlreadyRefreshed() error {
|
||||
if p.OnAlreadyRefresh == BackgroundSkipAsSuccess {
|
||||
return nil
|
||||
}
|
||||
return errRefreshSkipped
|
||||
}
|
||||
@@ -116,6 +116,15 @@ func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, e
|
||||
return s.parseSettings(settings), nil
|
||||
}
|
||||
|
||||
// GetFrontendURL 获取前端基础URL(数据库优先,fallback 到配置文件)
|
||||
func (s *SettingService) GetFrontendURL(ctx context.Context) string {
|
||||
val, err := s.settingRepo.GetValue(ctx, SettingKeyFrontendURL)
|
||||
if err == nil && strings.TrimSpace(val) != "" {
|
||||
return strings.TrimSpace(val)
|
||||
}
|
||||
return s.cfg.Server.FrontendURL
|
||||
}
|
||||
|
||||
// GetPublicSettings 获取公开设置(无需登录)
|
||||
func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) {
|
||||
keys := []string{
|
||||
@@ -401,6 +410,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
|
||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
|
||||
updates[SettingKeyFrontendURL] = settings.FrontendURL
|
||||
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
||||
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
|
||||
|
||||
@@ -767,6 +777,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]),
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
|
||||
FrontendURL: settings[SettingKeyFrontendURL],
|
||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
|
||||
@@ -6,6 +6,7 @@ type SystemSettings struct {
|
||||
RegistrationEmailSuffixWhitelist []string
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
FrontendURL string
|
||||
InvitationCodeEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
@@ -16,10 +17,13 @@ import (
|
||||
type TokenRefreshService struct {
|
||||
accountRepo AccountRepository
|
||||
refreshers []TokenRefresher
|
||||
executors []OAuthRefreshExecutor // 与 refreshers 一一对应的 executor(带 CacheKey)
|
||||
refreshPolicy BackgroundRefreshPolicy
|
||||
cfg *config.TokenRefreshConfig
|
||||
cacheInvalidator TokenCacheInvalidator
|
||||
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
||||
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
||||
refreshAPI *OAuthRefreshAPI // 统一刷新 API
|
||||
|
||||
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
@@ -43,6 +47,7 @@ func NewTokenRefreshService(
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
accountRepo: accountRepo,
|
||||
refreshPolicy: DefaultBackgroundRefreshPolicy(),
|
||||
cfg: &cfg.TokenRefresh,
|
||||
cacheInvalidator: cacheInvalidator,
|
||||
schedulerCache: schedulerCache,
|
||||
@@ -53,12 +58,24 @@ func NewTokenRefreshService(
|
||||
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
||||
|
||||
// 注册平台特定的刷新器
|
||||
claudeRefresher := NewClaudeTokenRefresher(oauthService)
|
||||
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
|
||||
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
|
||||
|
||||
// 注册平台特定的刷新器(TokenRefresher 接口)
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
claudeRefresher,
|
||||
openAIRefresher,
|
||||
NewGeminiTokenRefresher(geminiOAuthService),
|
||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
||||
geminiRefresher,
|
||||
agRefresher,
|
||||
}
|
||||
|
||||
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
|
||||
s.executors = []OAuthRefreshExecutor{
|
||||
claudeRefresher,
|
||||
openAIRefresher,
|
||||
geminiRefresher,
|
||||
agRefresher,
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -82,6 +99,16 @@ func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxy
|
||||
s.proxyRepo = proxyRepo
|
||||
}
|
||||
|
||||
// SetRefreshAPI 注入统一的 OAuth 刷新 API
|
||||
func (s *TokenRefreshService) SetRefreshAPI(api *OAuthRefreshAPI) {
|
||||
s.refreshAPI = api
|
||||
}
|
||||
|
||||
// SetRefreshPolicy 注入后台刷新调用侧策略(用于显式化平台/场景差异行为)。
|
||||
func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
|
||||
s.refreshPolicy = policy
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
@@ -148,13 +175,13 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
totalAccounts := len(accounts)
|
||||
oauthAccounts := 0 // 可刷新的OAuth账号数
|
||||
needsRefresh := 0 // 需要刷新的账号数
|
||||
refreshed, failed := 0, 0
|
||||
refreshed, failed, skipped := 0, 0, 0
|
||||
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
|
||||
// 遍历所有刷新器,找到能处理此账号的
|
||||
for _, refresher := range s.refreshers {
|
||||
for idx, refresher := range s.refreshers {
|
||||
if !refresher.CanRefresh(account) {
|
||||
continue
|
||||
}
|
||||
@@ -168,14 +195,24 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
|
||||
needsRefresh++
|
||||
|
||||
// 获取对应的 executor
|
||||
var executor OAuthRefreshExecutor
|
||||
if idx < len(s.executors) {
|
||||
executor = s.executors[idx]
|
||||
}
|
||||
|
||||
// 执行刷新
|
||||
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
|
||||
slog.Warn("token_refresh.account_refresh_failed",
|
||||
"account_id", account.ID,
|
||||
"account_name", account.Name,
|
||||
"error", err,
|
||||
)
|
||||
failed++
|
||||
if err := s.refreshWithRetry(ctx, account, refresher, executor, refreshWindow); err != nil {
|
||||
if errors.Is(err, errRefreshSkipped) {
|
||||
skipped++
|
||||
} else {
|
||||
slog.Warn("token_refresh.account_refresh_failed",
|
||||
"account_id", account.ID,
|
||||
"account_name", account.Name,
|
||||
"error", err,
|
||||
)
|
||||
failed++
|
||||
}
|
||||
} else {
|
||||
slog.Info("token_refresh.account_refreshed",
|
||||
"account_id", account.ID,
|
||||
@@ -193,13 +230,14 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
if needsRefresh == 0 && failed == 0 {
|
||||
slog.Debug("token_refresh.cycle_completed",
|
||||
"total", totalAccounts, "oauth", oauthAccounts,
|
||||
"needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed)
|
||||
"needs_refresh", needsRefresh, "refreshed", refreshed, "skipped", skipped, "failed", failed)
|
||||
} else {
|
||||
slog.Info("token_refresh.cycle_completed",
|
||||
"total", totalAccounts,
|
||||
"oauth", oauthAccounts,
|
||||
"needs_refresh", needsRefresh,
|
||||
"refreshed", refreshed,
|
||||
"skipped", skipped,
|
||||
"failed", failed,
|
||||
)
|
||||
}
|
||||
@@ -212,83 +250,42 @@ func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account
|
||||
}
|
||||
|
||||
// refreshWithRetry 带重试的刷新
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error {
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher, executor OAuthRefreshExecutor, refreshWindow time.Duration) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
var newCredentials map[string]any
|
||||
var err error
|
||||
|
||||
// 如果有新凭证,先更新(即使有错误也要保存 token)
|
||||
if newCredentials != nil {
|
||||
// 记录刷新版本时间戳,用于解决缓存一致性问题
|
||||
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
|
||||
account.Credentials = newCredentials
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
// 优先使用统一 API(带分布式锁 + DB 重读保护)
|
||||
if s.refreshAPI != nil && executor != nil {
|
||||
result, refreshErr := s.refreshAPI.RefreshIfNeeded(ctx, account, executor, refreshWindow)
|
||||
if refreshErr != nil {
|
||||
err = refreshErr
|
||||
} else if result.LockHeld {
|
||||
// 锁被其他 worker 持有,由调用侧策略决定如何计数
|
||||
return s.refreshPolicy.handleLockHeld()
|
||||
} else if !result.Refreshed {
|
||||
// 已被其他路径刷新,由调用侧策略决定如何计数
|
||||
return s.refreshPolicy.handleAlreadyRefreshed()
|
||||
} else {
|
||||
account = result.Account
|
||||
_ = result.NewCredentials // 统一 API 已设置 _token_version 并更新 DB,无需重复操作
|
||||
}
|
||||
} else {
|
||||
// 降级:直接调用 refresher(兼容旧路径)
|
||||
newCredentials, err = refresher.Refresh(ctx, account)
|
||||
if newCredentials != nil {
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
account.Credentials = newCredentials
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||
if account.Platform == PlatformAntigravity &&
|
||||
account.Status == StatusError &&
|
||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_account_error_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
||||
}
|
||||
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
||||
if s.tempUnschedCache != nil {
|
||||
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("token_refresh.invalidate_token_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
|
||||
// 这解决了 token 刷新后调度器缓存数据不一致的问题(#445)
|
||||
if s.schedulerCache != nil {
|
||||
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
slog.Warn("token_refresh.sync_scheduler_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享
|
||||
s.ensureOpenAIPrivacy(ctx, account)
|
||||
s.postRefreshActions(ctx, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -331,6 +328,70 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// postRefreshActions 刷新成功后的后续动作(清除错误状态、缓存失效、调度器同步等)
|
||||
func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *Account) {
|
||||
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||
if account.Platform == PlatformAntigravity &&
|
||||
account.Status == StatusError &&
|
||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_account_error_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
||||
}
|
||||
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
||||
if s.tempUnschedCache != nil {
|
||||
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
|
||||
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", clearErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("token_refresh.invalidate_token_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
|
||||
if s.schedulerCache != nil {
|
||||
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
slog.Warn("token_refresh.sync_scheduler_cache_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享
|
||||
s.ensureOpenAIPrivacy(ctx, account)
|
||||
}
|
||||
|
||||
// errRefreshSkipped 表示刷新被跳过(锁竞争或已被其他路径刷新),不计入 failed 或 refreshed
|
||||
var errRefreshSkipped = fmt.Errorf("refresh skipped")
|
||||
|
||||
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
|
||||
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
|
||||
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
|
||||
|
||||
@@ -84,6 +84,10 @@ func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map
|
||||
return r.credentials, nil
|
||||
}
|
||||
|
||||
func (r *tokenRefresherStub) CacheKey(account *Account) string {
|
||||
return "test:stub:" + account.Platform
|
||||
}
|
||||
|
||||
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
@@ -105,7 +109,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls)
|
||||
@@ -133,7 +137,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls)
|
||||
@@ -159,7 +163,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
}
|
||||
@@ -186,7 +190,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
|
||||
@@ -214,7 +218,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
|
||||
@@ -242,7 +246,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
||||
@@ -270,7 +274,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to save credentials")
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
@@ -297,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
||||
err: errors.New("refresh failed"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
||||
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
|
||||
@@ -324,7 +328,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
|
||||
err: errors.New("network error"), // 可重试错误
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls)
|
||||
@@ -351,7 +355,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
|
||||
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls)
|
||||
@@ -383,7 +387,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
|
||||
@@ -422,7 +426,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
||||
err: errors.New("invalid_grant: token revoked"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError
|
||||
})
|
||||
@@ -453,3 +457,212 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ========== Path A (refreshAPI) 测试用例 ==========
|
||||
|
||||
// mockTokenCacheForRefreshAPI 用于 Path A 测试的 GeminiTokenCache mock
|
||||
type mockTokenCacheForRefreshAPI struct {
|
||||
lockResult bool
|
||||
lockErr error
|
||||
releaseCalls int
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) GetAccessToken(_ context.Context, _ string) (string, error) {
|
||||
return "", errors.New("not cached")
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) SetAccessToken(_ context.Context, _ string, _ string, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) DeleteAccessToken(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) AcquireRefreshLock(_ context.Context, _ string, _ time.Duration) (bool, error) {
|
||||
return m.lockResult, m.lockErr
|
||||
}
|
||||
|
||||
func (m *mockTokenCacheForRefreshAPI) ReleaseRefreshLock(_ context.Context, _ string) error {
|
||||
m.releaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPathAService 构建注入了 refreshAPI 的 service(Path A 测试辅助)
|
||||
func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, invalidator TokenCacheInvalidator) (*TokenRefreshService, *tokenRefresherStub) {
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "refreshed-token",
|
||||
},
|
||||
}
|
||||
return service, refresher
|
||||
}
|
||||
|
||||
// TestPathA_Success 统一 API 路径正常成功:刷新 + DB 更新 + postRefreshActions
|
||||
func TestPathA_Success(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls) // DB 更新被调用
|
||||
require.Equal(t, 1, invalidator.calls) // 缓存失效被调用
|
||||
require.Equal(t, 1, cache.releaseCalls) // 锁被释放
|
||||
}
|
||||
|
||||
// TestPathA_LockHeld 锁被其他 worker 持有 → 返回 errRefreshSkipped
|
||||
func TestPathA_LockHeld(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: false} // 锁获取失败(被占)
|
||||
|
||||
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.ErrorIs(t, err, errRefreshSkipped)
|
||||
require.Equal(t, 0, repo.updateCalls) // 不应更新 DB
|
||||
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||
}
|
||||
|
||||
// TestPathA_AlreadyRefreshed 二次检查发现已被其他路径刷新 → 返回 errRefreshSkipped
|
||||
func TestPathA_AlreadyRefreshed(t *testing.T) {
|
||||
// NeedsRefresh 返回 false → RefreshIfNeeded 返回 {Refreshed: false}
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
service, _ := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
// 使用一个 NeedsRefresh 返回 false 的 stub
|
||||
noRefreshNeeded := &tokenRefresherStub{
|
||||
credentials: map[string]any{"access_token": "token"},
|
||||
}
|
||||
// 覆盖 NeedsRefresh 行为 — 我们需要一个新的 stub 类型
|
||||
alwaysFreshStub := &alwaysFreshRefresherStub{}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, noRefreshNeeded, alwaysFreshStub, time.Hour)
|
||||
require.ErrorIs(t, err, errRefreshSkipped)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls)
|
||||
}
|
||||
|
||||
// alwaysFreshRefresherStub 二次检查时认为不需要刷新(模拟已被其他路径刷新)
|
||||
type alwaysFreshRefresherStub struct{}
|
||||
|
||||
func (r *alwaysFreshRefresherStub) CanRefresh(_ *Account) bool { return true }
|
||||
func (r *alwaysFreshRefresherStub) NeedsRefresh(_ *Account, _ time.Duration) bool { return false }
|
||||
func (r *alwaysFreshRefresherStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
}
|
||||
func (r *alwaysFreshRefresherStub) CacheKey(account *Account) string {
|
||||
return "test:fresh:" + account.Platform
|
||||
}
|
||||
|
||||
// TestPathA_NonRetryableError 统一 API 路径返回不可重试错误 → SetError
|
||||
func TestPathA_NonRetryableError(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
service, _ := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
refresher := &tokenRefresherStub{
|
||||
err: errors.New("invalid_grant: token revoked"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, repo.setErrorCalls) // 应标记 error 状态
|
||||
require.Equal(t, 0, repo.updateCalls) // 不应更新 credentials
|
||||
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||
}
|
||||
|
||||
// TestPathA_RetryableErrorExhausted 统一 API 路径可重试错误耗尽 → 不标记 error
|
||||
func TestPathA_RetryableErrorExhausted(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 2,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||
service.SetRefreshAPI(refreshAPI)
|
||||
|
||||
refresher := &tokenRefresherStub{
|
||||
err: errors.New("network timeout"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.setErrorCalls) // 可重试错误不标记 error
|
||||
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
||||
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||
}
|
||||
|
||||
// TestPathA_DBUpdateFailed 统一 API 路径 DB 更新失败 → 返回 error,不执行 postRefreshActions
|
||||
func TestPathA_DBUpdateFailed(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
repo := &tokenRefreshAccountRepo{updateErr: errors.New("db connection lost")}
|
||||
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||
|
||||
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "DB update failed")
|
||||
require.Equal(t, 1, repo.updateCalls) // DB 更新被尝试
|
||||
require.Equal(t, 0, invalidator.calls) // DB 失败时不应触发缓存失效
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -33,6 +32,11 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
|
||||
}
|
||||
}
|
||||
|
||||
// CacheKey 返回用于分布式锁的缓存键
|
||||
func (r *ClaudeTokenRefresher) CacheKey(account *Account) string {
|
||||
return ClaudeTokenCacheKey(account)
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 anthropic 平台的 oauth 类型账号
|
||||
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
|
||||
@@ -59,24 +63,8 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留现有credentials中的所有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
// 只更新token相关字段
|
||||
// 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
newCredentials := BuildClaudeAccountCredentials(tokenInfo)
|
||||
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
@@ -97,6 +85,11 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo
|
||||
}
|
||||
}
|
||||
|
||||
// CacheKey 返回用于分布式锁的缓存键
|
||||
func (r *OpenAITokenRefresher) CacheKey(account *Account) string {
|
||||
return OpenAITokenCacheKey(account)
|
||||
}
|
||||
|
||||
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
|
||||
// 用于在 Token 刷新时同步更新 sora_accounts 表
|
||||
// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials
|
||||
@@ -137,13 +130,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
||||
|
||||
// 使用服务提供的方法构建新凭证,并保留原有字段
|
||||
newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// 保留原有credentials中非token相关字段
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||
|
||||
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
||||
if r.accountRepo != nil && r.syncLinkedSora {
|
||||
|
||||
@@ -100,9 +100,14 @@ type UsageLog struct {
|
||||
Model string
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API),
|
||||
// e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable.
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
// OpenAI: "low" / "medium" / "high" / "xhigh"; Claude: "low" / "medium" / "high" / "max".
|
||||
// Nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint *string
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint *string
|
||||
|
||||
GroupID *int64
|
||||
SubscriptionID *int64
|
||||
|
||||
@@ -51,16 +51,77 @@ func ProvideTokenRefreshService(
|
||||
tempUnschedCache TempUnschedCache,
|
||||
privacyClientFactory PrivacyClientFactory,
|
||||
proxyRepo ProxyRepository,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
||||
svc.SetSoraAccountRepo(soraAccountRepo)
|
||||
// 注入 OpenAI privacy opt-out 依赖
|
||||
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
||||
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
|
||||
svc.SetRefreshAPI(refreshAPI)
|
||||
// 调用侧显式注入后台刷新策略,避免策略漂移
|
||||
svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy())
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideClaudeTokenProvider creates ClaudeTokenProvider with OAuthRefreshAPI injection
|
||||
func ProvideClaudeTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache GeminiTokenCache,
|
||||
oauthService *OAuthService,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *ClaudeTokenProvider {
|
||||
p := NewClaudeTokenProvider(accountRepo, tokenCache, oauthService)
|
||||
executor := NewClaudeTokenRefresher(oauthService)
|
||||
p.SetRefreshAPI(refreshAPI, executor)
|
||||
p.SetRefreshPolicy(ClaudeProviderRefreshPolicy())
|
||||
return p
|
||||
}
|
||||
|
||||
// ProvideOpenAITokenProvider creates OpenAITokenProvider with OAuthRefreshAPI injection
|
||||
func ProvideOpenAITokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache GeminiTokenCache,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *OpenAITokenProvider {
|
||||
p := NewOpenAITokenProvider(accountRepo, tokenCache, openaiOAuthService)
|
||||
executor := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||
p.SetRefreshAPI(refreshAPI, executor)
|
||||
p.SetRefreshPolicy(OpenAIProviderRefreshPolicy())
|
||||
return p
|
||||
}
|
||||
|
||||
// ProvideGeminiTokenProvider creates GeminiTokenProvider with OAuthRefreshAPI injection
|
||||
func ProvideGeminiTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache GeminiTokenCache,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *GeminiTokenProvider {
|
||||
p := NewGeminiTokenProvider(accountRepo, tokenCache, geminiOAuthService)
|
||||
executor := NewGeminiTokenRefresher(geminiOAuthService)
|
||||
p.SetRefreshAPI(refreshAPI, executor)
|
||||
p.SetRefreshPolicy(GeminiProviderRefreshPolicy())
|
||||
return p
|
||||
}
|
||||
|
||||
// ProvideAntigravityTokenProvider creates AntigravityTokenProvider with OAuthRefreshAPI injection
|
||||
func ProvideAntigravityTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache GeminiTokenCache,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *AntigravityTokenProvider {
|
||||
p := NewAntigravityTokenProvider(accountRepo, tokenCache, antigravityOAuthService)
|
||||
executor := NewAntigravityTokenRefresher(antigravityOAuthService)
|
||||
p.SetRefreshAPI(refreshAPI, executor)
|
||||
p.SetRefreshPolicy(AntigravityProviderRefreshPolicy())
|
||||
return p
|
||||
}
|
||||
|
||||
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
||||
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
||||
@@ -375,11 +436,12 @@ var ProviderSet = wire.NewSet(
|
||||
NewCompositeTokenCacheInvalidator,
|
||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||
NewAntigravityOAuthService,
|
||||
NewGeminiTokenProvider,
|
||||
NewOAuthRefreshAPI,
|
||||
ProvideGeminiTokenProvider,
|
||||
NewGeminiMessagesCompatService,
|
||||
NewAntigravityTokenProvider,
|
||||
NewOpenAITokenProvider,
|
||||
NewClaudeTokenProvider,
|
||||
ProvideAntigravityTokenProvider,
|
||||
ProvideOpenAITokenProvider,
|
||||
ProvideClaudeTokenProvider,
|
||||
NewAntigravityGatewayService,
|
||||
ProvideRateLimitService,
|
||||
NewAccountUsageService,
|
||||
|
||||
5
backend/migrations/074_add_usage_log_endpoints.sql
Normal file
5
backend/migrations/074_add_usage_log_endpoints.sql
Normal file
@@ -0,0 +1,5 @@
|
||||
-- Add endpoint tracking fields to usage_logs.
|
||||
-- inbound_endpoint: client-facing API route (e.g. /v1/chat/completions, /v1/messages, /v1/responses)
|
||||
-- upstream_endpoint: normalized upstream route (e.g. /v1/responses)
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(128);
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(128);
|
||||
@@ -841,6 +841,7 @@ export interface OpsAdvancedSettings {
|
||||
ignore_context_canceled: boolean
|
||||
ignore_no_available_accounts: boolean
|
||||
ignore_invalid_api_key_errors: boolean
|
||||
ignore_insufficient_balance_errors: boolean
|
||||
display_openai_token_stats: boolean
|
||||
display_alert_events: boolean
|
||||
auto_refresh_enabled: boolean
|
||||
|
||||
@@ -21,6 +21,7 @@ export interface SystemSettings {
|
||||
registration_email_suffix_whitelist: string[]
|
||||
promo_code_enabled: boolean
|
||||
password_reset_enabled: boolean
|
||||
frontend_url: string
|
||||
invitation_code_enabled: boolean
|
||||
totp_enabled: boolean // TOTP 双因素认证
|
||||
totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置
|
||||
@@ -91,6 +92,7 @@ export interface UpdateSettingsRequest {
|
||||
registration_email_suffix_whitelist?: string[]
|
||||
promo_code_enabled?: boolean
|
||||
password_reset_enabled?: boolean
|
||||
frontend_url?: string
|
||||
invitation_code_enabled?: boolean
|
||||
totp_enabled?: boolean // TOTP 双因素认证
|
||||
default_balance?: number
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
import { apiClient } from '../client'
|
||||
import type { AdminUsageLog, UsageQueryParams, PaginatedResponse, UsageRequestType } from '@/types'
|
||||
import type { EndpointStat } from '@/types'
|
||||
|
||||
// ==================== Types ====================
|
||||
|
||||
@@ -18,6 +19,9 @@ export interface AdminUsageStatsResponse {
|
||||
total_actual_cost: number
|
||||
total_account_cost?: number
|
||||
average_duration_ms: number
|
||||
endpoints?: EndpointStat[]
|
||||
upstream_endpoints?: EndpointStat[]
|
||||
endpoint_paths?: EndpointStat[]
|
||||
}
|
||||
|
||||
export interface SimpleUser {
|
||||
|
||||
@@ -446,6 +446,18 @@
|
||||
|
||||
<!-- Model Distribution -->
|
||||
<ModelDistributionChart :model-stats="stats.models" :loading="false" />
|
||||
|
||||
<EndpointDistributionChart
|
||||
:endpoint-stats="stats.endpoints || []"
|
||||
:loading="false"
|
||||
:title="t('usage.inboundEndpoint')"
|
||||
/>
|
||||
|
||||
<EndpointDistributionChart
|
||||
:endpoint-stats="stats.upstream_endpoints || []"
|
||||
:loading="false"
|
||||
:title="t('usage.upstreamEndpoint')"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<!-- No Data State -->
|
||||
@@ -489,6 +501,7 @@ import { Line } from 'vue-chartjs'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
|
||||
import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, AccountUsageStatsResponse } from '@/types'
|
||||
|
||||
@@ -76,19 +76,39 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
|
||||
<!-- Model Status Indicators (普通限流 / 超量请求中) -->
|
||||
<div
|
||||
v-if="activeModelRateLimits.length > 0"
|
||||
v-if="activeModelStatuses.length > 0"
|
||||
:class="[
|
||||
activeModelRateLimits.length <= 4
|
||||
activeModelStatuses.length <= 4
|
||||
? 'flex flex-col gap-1'
|
||||
: activeModelRateLimits.length <= 8
|
||||
: activeModelStatuses.length <= 8
|
||||
? 'columns-2 gap-x-2'
|
||||
: 'columns-3 gap-x-2'
|
||||
]"
|
||||
>
|
||||
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative mb-1 break-inside-avoid">
|
||||
<div v-for="item in activeModelStatuses" :key="`${item.kind}-${item.model}`" class="group relative mb-1 break-inside-avoid">
|
||||
<!-- 积分已用尽 -->
|
||||
<span
|
||||
v-if="item.kind === 'credits_exhausted'"
|
||||
class="inline-flex items-center gap-1 rounded bg-red-100 px-1.5 py-0.5 text-xs font-medium text-red-700 dark:bg-red-900/30 dark:text-red-400"
|
||||
>
|
||||
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
|
||||
{{ t('admin.accounts.status.creditsExhausted') }}
|
||||
<span class="text-[10px] opacity-70">{{ formatModelResetTime(item.reset_at) }}</span>
|
||||
</span>
|
||||
<!-- 正在走积分(模型限流但积分可用)-->
|
||||
<span
|
||||
v-else-if="item.kind === 'credits_active'"
|
||||
class="inline-flex items-center gap-1 rounded bg-amber-100 px-1.5 py-0.5 text-xs font-medium text-amber-700 dark:bg-amber-900/30 dark:text-amber-400"
|
||||
>
|
||||
<span>⚡</span>
|
||||
{{ formatScopeName(item.model) }}
|
||||
<span class="text-[10px] opacity-70">{{ formatModelResetTime(item.reset_at) }}</span>
|
||||
</span>
|
||||
<!-- 普通模型限流 -->
|
||||
<span
|
||||
v-else
|
||||
class="inline-flex items-center gap-1 rounded bg-purple-100 px-1.5 py-0.5 text-xs font-medium text-purple-700 dark:bg-purple-900/30 dark:text-purple-400"
|
||||
>
|
||||
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
|
||||
@@ -99,7 +119,13 @@
|
||||
<div
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 w-56 -translate-x-1/2 whitespace-normal rounded bg-gray-900 px-3 py-2 text-center text-xs leading-relaxed text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
>
|
||||
{{ t('admin.accounts.status.modelRateLimitedUntil', { model: formatScopeName(item.model), time: formatTime(item.reset_at) }) }}
|
||||
{{
|
||||
item.kind === 'credits_exhausted'
|
||||
? t('admin.accounts.status.creditsExhaustedUntil', { time: formatTime(item.reset_at) })
|
||||
: item.kind === 'credits_active'
|
||||
? t('admin.accounts.status.modelCreditOveragesUntil', { model: formatScopeName(item.model), time: formatTime(item.reset_at) })
|
||||
: t('admin.accounts.status.modelRateLimitedUntil', { model: formatScopeName(item.model), time: formatTime(item.reset_at) })
|
||||
}}
|
||||
<div
|
||||
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700"
|
||||
></div>
|
||||
@@ -131,6 +157,7 @@
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import type { Account } from '@/types'
|
||||
import { formatCountdown, formatDateTime, formatCountdownWithSuffix, formatTime } from '@/utils/format'
|
||||
|
||||
@@ -150,17 +177,44 @@ const isRateLimited = computed(() => {
|
||||
return new Date(props.account.rate_limit_reset_at) > new Date()
|
||||
})
|
||||
|
||||
type AccountModelStatusItem = {
|
||||
kind: 'rate_limit' | 'credits_exhausted' | 'credits_active'
|
||||
model: string
|
||||
reset_at: string
|
||||
}
|
||||
|
||||
// Computed: active model rate limits (Antigravity OAuth Smart Retry)
|
||||
const activeModelRateLimits = computed(() => {
|
||||
const modelLimits = (props.account.extra as Record<string, unknown> | undefined)?.model_rate_limits as
|
||||
// Computed: active model statuses (普通模型限流 + 积分耗尽 + 走积分中)
|
||||
const activeModelStatuses = computed<AccountModelStatusItem[]>(() => {
|
||||
const extra = props.account.extra as Record<string, unknown> | undefined
|
||||
const modelLimits = extra?.model_rate_limits as
|
||||
| Record<string, { rate_limited_at: string; rate_limit_reset_at: string }>
|
||||
| undefined
|
||||
if (!modelLimits) return []
|
||||
const now = new Date()
|
||||
return Object.entries(modelLimits)
|
||||
.filter(([, info]) => new Date(info.rate_limit_reset_at) > now)
|
||||
.map(([model, info]) => ({ model, reset_at: info.rate_limit_reset_at }))
|
||||
const items: AccountModelStatusItem[] = []
|
||||
|
||||
if (!modelLimits) return items
|
||||
|
||||
// 检查 AICredits key 是否生效(积分是否耗尽)
|
||||
const aiCreditsEntry = modelLimits['AICredits']
|
||||
const hasActiveAICredits = aiCreditsEntry && new Date(aiCreditsEntry.rate_limit_reset_at) > now
|
||||
const allowOverages = !!(extra?.allow_overages)
|
||||
|
||||
for (const [model, info] of Object.entries(modelLimits)) {
|
||||
if (new Date(info.rate_limit_reset_at) <= now) continue
|
||||
|
||||
if (model === 'AICredits') {
|
||||
// AICredits key → 积分已用尽
|
||||
items.push({ kind: 'credits_exhausted', model, reset_at: info.rate_limit_reset_at })
|
||||
} else if (allowOverages && !hasActiveAICredits) {
|
||||
// 普通模型限流 + overages 启用 + 积分可用 → 正在走积分
|
||||
items.push({ kind: 'credits_active', model, reset_at: info.rate_limit_reset_at })
|
||||
} else {
|
||||
// 普通模型限流
|
||||
items.push({ kind: 'rate_limit', model, reset_at: info.rate_limit_reset_at })
|
||||
}
|
||||
}
|
||||
|
||||
return items
|
||||
})
|
||||
|
||||
const formatScopeName = (scope: string): string => {
|
||||
@@ -182,7 +236,7 @@ const formatScopeName = (scope: string): string => {
|
||||
'gemini-3.1-pro-high': 'G3PH',
|
||||
'gemini-3.1-pro-low': 'G3PL',
|
||||
'gemini-3-pro-image': 'G3PI',
|
||||
'gemini-3.1-flash-image': 'GImage',
|
||||
'gemini-3.1-flash-image': 'G31FI',
|
||||
// 其他
|
||||
'gpt-oss-120b-medium': 'GPT120',
|
||||
'tab_flash_lite_preview': 'TabFL',
|
||||
|
||||
@@ -289,6 +289,13 @@
|
||||
:resets-at="antigravityClaudeUsageFromAPI.resetTime"
|
||||
color="amber"
|
||||
/>
|
||||
|
||||
<div v-if="aiCreditsDisplay" class="mt-1 text-[10px] text-gray-500 dark:text-gray-400">
|
||||
💳 {{ t('admin.accounts.aiCreditsBalance') }}: {{ aiCreditsDisplay }}
|
||||
</div>
|
||||
</div>
|
||||
<div v-else-if="aiCreditsDisplay" class="text-[10px] text-gray-500 dark:text-gray-400">
|
||||
💳 {{ t('admin.accounts.aiCreditsBalance') }}: {{ aiCreditsDisplay }}
|
||||
</div>
|
||||
<div v-else class="text-xs text-gray-400">-</div>
|
||||
</template>
|
||||
@@ -581,6 +588,14 @@ const antigravityClaudeUsageFromAPI = computed(() =>
|
||||
])
|
||||
)
|
||||
|
||||
const aiCreditsDisplay = computed(() => {
|
||||
const credits = usageInfo.value?.ai_credits
|
||||
if (!credits || credits.length === 0) return null
|
||||
const total = credits.reduce((sum, credit) => sum + (credit.amount ?? 0), 0)
|
||||
if (total <= 0) return null
|
||||
return total.toFixed(0)
|
||||
})
|
||||
|
||||
// Antigravity 账户类型(从 load_code_assist 响应中提取)
|
||||
const antigravityTier = computed(() => {
|
||||
const extra = props.account.extra as Record<string, unknown> | undefined
|
||||
|
||||
@@ -164,27 +164,10 @@
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Checkbox List -->
|
||||
<div class="mb-3 grid grid-cols-2 gap-2">
|
||||
<label
|
||||
v-for="model in filteredModels"
|
||||
:key="model.value"
|
||||
class="flex cursor-pointer items-center rounded-lg border p-3 transition-all hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700"
|
||||
:class="
|
||||
allowedModels.includes(model.value)
|
||||
? 'border-primary-500 bg-primary-50 dark:bg-primary-900/20'
|
||||
: 'border-gray-200'
|
||||
"
|
||||
>
|
||||
<input
|
||||
v-model="allowedModels"
|
||||
type="checkbox"
|
||||
:value="model.value"
|
||||
class="mr-2 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ model.label }}</span>
|
||||
</label>
|
||||
</div>
|
||||
<ModelWhitelistSelector
|
||||
v-model="allowedModels"
|
||||
:platforms="selectedPlatforms"
|
||||
/>
|
||||
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
@@ -832,8 +815,12 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { buildModelMappingObject as buildModelMappingPayload } from '@/composables/useModelWhitelist'
|
||||
import {
|
||||
buildModelMappingObject as buildModelMappingPayload,
|
||||
getPresetMappingsByPlatform
|
||||
} from '@/composables/useModelWhitelist'
|
||||
|
||||
interface Props {
|
||||
show: boolean
|
||||
@@ -865,26 +852,20 @@ const allAnthropicOAuthOrSetupToken = computed(() => {
|
||||
)
|
||||
})
|
||||
|
||||
const platformModelPrefix: Record<string, string[]> = {
|
||||
anthropic: ['claude-'],
|
||||
antigravity: ['claude-', 'gemini-', 'gpt-oss-', 'tab_'],
|
||||
openai: ['gpt-'],
|
||||
gemini: ['gemini-'],
|
||||
sora: []
|
||||
}
|
||||
|
||||
const filteredModels = computed(() => {
|
||||
if (props.selectedPlatforms.length === 0) return allModels
|
||||
const prefixes = [...new Set(props.selectedPlatforms.flatMap(p => platformModelPrefix[p] || []))]
|
||||
if (prefixes.length === 0) return allModels
|
||||
return allModels.filter(m => prefixes.some(prefix => m.value.startsWith(prefix)))
|
||||
})
|
||||
|
||||
const filteredPresets = computed(() => {
|
||||
if (props.selectedPlatforms.length === 0) return presetMappings
|
||||
const prefixes = [...new Set(props.selectedPlatforms.flatMap(p => platformModelPrefix[p] || []))]
|
||||
if (prefixes.length === 0) return presetMappings
|
||||
return presetMappings.filter(m => prefixes.some(prefix => m.from.startsWith(prefix)))
|
||||
if (props.selectedPlatforms.length === 0) return []
|
||||
|
||||
const dedupedPresets = new Map<string, ReturnType<typeof getPresetMappingsByPlatform>[number]>()
|
||||
for (const platform of props.selectedPlatforms) {
|
||||
for (const preset of getPresetMappingsByPlatform(platform)) {
|
||||
const key = `${preset.from}=>${preset.to}`
|
||||
if (!dedupedPresets.has(key)) {
|
||||
dedupedPresets.set(key, preset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(dedupedPresets.values())
|
||||
})
|
||||
|
||||
// Model mapping type
|
||||
@@ -937,204 +918,6 @@ const umqModeOptions = computed(() => [
|
||||
{ value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') },
|
||||
])
|
||||
|
||||
// All models list (combined Anthropic + OpenAI + Gemini)
|
||||
const allModels = [
|
||||
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
|
||||
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
|
||||
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
|
||||
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
|
||||
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
|
||||
{ value: 'claude-3-5-haiku-20241022', label: 'Claude 3.5 Haiku' },
|
||||
{ value: 'claude-haiku-4-5-20251001', label: 'Claude Haiku 4.5' },
|
||||
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
|
||||
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
|
||||
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
|
||||
{ value: 'gpt-5.3-codex', label: 'GPT-5.3 Codex' },
|
||||
{ value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' },
|
||||
{ value: 'gpt-5.4', label: 'GPT-5.4' },
|
||||
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
||||
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
||||
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
|
||||
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
|
||||
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
||||
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
|
||||
{ value: 'gemini-3.1-flash-image', label: 'Gemini 3.1 Flash Image' },
|
||||
{ value: 'gemini-2.5-flash-image', label: 'Gemini 2.5 Flash Image' },
|
||||
{ value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' },
|
||||
{ value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' },
|
||||
{ value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' },
|
||||
{ value: 'gemini-3-pro-image', label: 'Gemini 3 Pro Image (Legacy)' },
|
||||
{ value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' },
|
||||
{ value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' }
|
||||
]
|
||||
|
||||
// Preset mappings (combined Anthropic + OpenAI + Gemini)
|
||||
const presetMappings = [
|
||||
{
|
||||
label: 'Sonnet 4',
|
||||
from: 'claude-sonnet-4-20250514',
|
||||
to: 'claude-sonnet-4-20250514',
|
||||
color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet 4.5',
|
||||
from: 'claude-sonnet-4-5-20250929',
|
||||
to: 'claude-sonnet-4-5-20250929',
|
||||
color:
|
||||
'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus 4.5',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-opus-4-5-20251101',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus 4.6',
|
||||
from: 'claude-opus-4-6',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus 4.6-thinking',
|
||||
from: 'claude-opus-4-6-thinking',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet 4.6',
|
||||
from: 'claude-sonnet-4-6',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet4→4.6',
|
||||
from: 'claude-sonnet-4-20250514',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet4.5→4.6',
|
||||
from: 'claude-sonnet-4-5-20250929',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet3.5→4.6',
|
||||
from: 'claude-3-5-sonnet-20241022',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus4.5→4.6',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus->Sonnet',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-sonnet-4-5-20250929',
|
||||
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||
},
|
||||
{
|
||||
label: 'Gemini 2.5 Image',
|
||||
from: 'gemini-2.5-flash-image',
|
||||
to: 'gemini-2.5-flash-image',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'Gemini 3.1 Image',
|
||||
from: 'gemini-3.1-flash-image',
|
||||
to: 'gemini-3.1-flash-image',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'G3 Image→3.1',
|
||||
from: 'gemini-3-pro-image',
|
||||
to: 'gemini-3.1-flash-image',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.3 Codex',
|
||||
from: 'gpt-5.3-codex',
|
||||
to: 'gpt-5.3-codex',
|
||||
color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.3 Spark',
|
||||
from: 'gpt-5.3-codex-spark',
|
||||
to: 'gpt-5.3-codex-spark',
|
||||
color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.4',
|
||||
from: 'gpt-5.4',
|
||||
to: 'gpt-5.4',
|
||||
color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400'
|
||||
},
|
||||
{
|
||||
label: '5.2→5.3',
|
||||
from: 'gpt-5.2-codex',
|
||||
to: 'gpt-5.3-codex',
|
||||
color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.2',
|
||||
from: 'gpt-5.2-2025-12-11',
|
||||
to: 'gpt-5.2-2025-12-11',
|
||||
color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.2 Codex',
|
||||
from: 'gpt-5.2-codex',
|
||||
to: 'gpt-5.2-codex',
|
||||
color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400'
|
||||
},
|
||||
{
|
||||
label: 'Max->Codex',
|
||||
from: 'gpt-5.1-codex-max',
|
||||
to: 'gpt-5.1-codex',
|
||||
color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-Preview→3.1-Pro-High',
|
||||
from: 'gemini-3-pro-preview',
|
||||
to: 'gemini-3.1-pro-high',
|
||||
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-High→3.1-Pro-High',
|
||||
from: 'gemini-3-pro-high',
|
||||
to: 'gemini-3.1-pro-high',
|
||||
color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-Low→3.1-Pro-Low',
|
||||
from: 'gemini-3-pro-low',
|
||||
to: 'gemini-3.1-pro-low',
|
||||
color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400'
|
||||
},
|
||||
{
|
||||
label: '3-Flash透传',
|
||||
from: 'gemini-3-flash',
|
||||
to: 'gemini-3-flash',
|
||||
color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400'
|
||||
},
|
||||
{
|
||||
label: '2.5-Flash-Lite透传',
|
||||
from: 'gemini-2.5-flash-lite',
|
||||
to: 'gemini-2.5-flash-lite',
|
||||
color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400'
|
||||
}
|
||||
]
|
||||
|
||||
// Common HTTP error codes
|
||||
const commonErrorCodes = [
|
||||
{ value: 401, label: 'Unauthorized' },
|
||||
|
||||
@@ -2449,6 +2449,33 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="form.platform === 'antigravity'" class="mt-3 flex items-center gap-2">
|
||||
<label class="flex cursor-pointer items-center gap-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
v-model="allowOverages"
|
||||
class="h-4 w-4 rounded border-gray-300 text-primary-500 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.accounts.allowOverages') }}
|
||||
</span>
|
||||
</label>
|
||||
<div class="group relative">
|
||||
<span
|
||||
class="inline-flex h-4 w-4 cursor-help items-center justify-center rounded-full bg-gray-200 text-xs text-gray-500 hover:bg-gray-300 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500"
|
||||
>
|
||||
?
|
||||
</span>
|
||||
<div
|
||||
class="pointer-events-none absolute left-0 top-full z-[100] mt-1.5 w-72 rounded bg-gray-900 px-3 py-2 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
>
|
||||
{{ t('admin.accounts.allowOveragesTooltip') }}
|
||||
<div
|
||||
class="absolute bottom-full left-3 border-4 border-transparent border-b-gray-900 dark:border-b-gray-700"
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Group Selection - 仅标准模式显示 -->
|
||||
<GroupSelector
|
||||
@@ -2991,6 +3018,7 @@ const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OF
|
||||
const codexCLIOnlyEnabled = ref(false)
|
||||
const anthropicPassthroughEnabled = ref(false)
|
||||
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
||||
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
|
||||
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
|
||||
const soraAccountType = ref<'oauth' | 'apikey'>('oauth') // For sora: oauth or apikey (upstream)
|
||||
const upstreamBaseUrl = ref('') // For upstream type: base URL
|
||||
@@ -3017,6 +3045,13 @@ const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>
|
||||
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
|
||||
const geminiAIStudioOAuthEnabled = ref(false)
|
||||
|
||||
function buildAntigravityExtra(): Record<string, unknown> | undefined {
|
||||
const extra: Record<string, unknown> = {}
|
||||
if (mixedScheduling.value) extra.mixed_scheduling = true
|
||||
if (allowOverages.value) extra.allow_overages = true
|
||||
return Object.keys(extra).length > 0 ? extra : undefined
|
||||
}
|
||||
|
||||
const showMixedChannelWarning = ref(false)
|
||||
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
|
||||
null
|
||||
@@ -3282,6 +3317,7 @@ watch(
|
||||
accountCategory.value = 'oauth-based'
|
||||
antigravityAccountType.value = 'oauth'
|
||||
} else {
|
||||
allowOverages.value = false
|
||||
antigravityWhitelistModels.value = []
|
||||
antigravityModelMappings.value = []
|
||||
antigravityModelRestrictionMode.value = 'mapping'
|
||||
@@ -3712,6 +3748,7 @@ const resetForm = () => {
|
||||
sessionIdMaskingEnabled.value = false
|
||||
cacheTTLOverrideEnabled.value = false
|
||||
cacheTTLOverrideTarget.value = '5m'
|
||||
allowOverages.value = false
|
||||
antigravityAccountType.value = 'oauth'
|
||||
upstreamBaseUrl.value = ''
|
||||
upstreamApiKey.value = ''
|
||||
@@ -3960,7 +3997,7 @@ const handleSubmit = async () => {
|
||||
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
|
||||
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
|
||||
const extra = buildAntigravityExtra()
|
||||
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
|
||||
return
|
||||
}
|
||||
@@ -4706,7 +4743,7 @@ const handleAntigravityExchange = async (authCode: string) => {
|
||||
if (antigravityModelMapping) {
|
||||
credentials.model_mapping = antigravityModelMapping
|
||||
}
|
||||
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
|
||||
const extra = buildAntigravityExtra()
|
||||
await createAccountAndFinish('antigravity', 'oauth', credentials, extra)
|
||||
} catch (error: any) {
|
||||
antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
|
||||
@@ -1610,6 +1610,33 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="account?.platform === 'antigravity'" class="mt-3 flex items-center gap-2">
|
||||
<label class="flex cursor-pointer items-center gap-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
v-model="allowOverages"
|
||||
class="h-4 w-4 rounded border-gray-300 text-primary-500 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.accounts.allowOverages') }}
|
||||
</span>
|
||||
</label>
|
||||
<div class="group relative">
|
||||
<span
|
||||
class="inline-flex h-4 w-4 cursor-help items-center justify-center rounded-full bg-gray-200 text-xs text-gray-500 hover:bg-gray-300 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500"
|
||||
>
|
||||
?
|
||||
</span>
|
||||
<div
|
||||
class="pointer-events-none absolute left-0 top-full z-[100] mt-1.5 w-72 rounded bg-gray-900 px-3 py-2 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
>
|
||||
{{ t('admin.accounts.allowOveragesTooltip') }}
|
||||
<div
|
||||
class="absolute bottom-full left-3 border-4 border-transparent border-b-gray-900 dark:border-b-gray-700"
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Group Selection - 仅标准模式显示 -->
|
||||
@@ -1778,6 +1805,7 @@ const customErrorCodeInput = ref<number | null>(null)
|
||||
const interceptWarmupRequests = ref(false)
|
||||
const autoPauseOnExpired = ref(false)
|
||||
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
||||
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
|
||||
const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const antigravityWhitelistModels = ref<string[]>([])
|
||||
const antigravityModelMappings = ref<ModelMapping[]>([])
|
||||
@@ -1980,8 +2008,11 @@ watch(
|
||||
autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true
|
||||
|
||||
// Load mixed scheduling setting (only for antigravity accounts)
|
||||
mixedScheduling.value = false
|
||||
allowOverages.value = false
|
||||
const extra = newAccount.extra as Record<string, unknown> | undefined
|
||||
mixedScheduling.value = extra?.mixed_scheduling === true
|
||||
allowOverages.value = extra?.allow_overages === true
|
||||
|
||||
// Load OpenAI passthrough toggle (OpenAI OAuth/API Key)
|
||||
openaiPassthroughEnabled.value = false
|
||||
@@ -2822,7 +2853,7 @@ const handleSubmit = async () => {
|
||||
updatePayload.credentials = newCredentials
|
||||
}
|
||||
|
||||
// For antigravity accounts, handle mixed_scheduling in extra
|
||||
// For antigravity accounts, handle mixed_scheduling and allow_overages in extra
|
||||
if (props.account.platform === 'antigravity') {
|
||||
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
|
||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||
@@ -2831,6 +2862,11 @@ const handleSubmit = async () => {
|
||||
} else {
|
||||
delete newExtra.mixed_scheduling
|
||||
}
|
||||
if (allowOverages.value) {
|
||||
newExtra.allow_overages = true
|
||||
} else {
|
||||
delete newExtra.allow_overages
|
||||
}
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +131,8 @@ const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
modelValue: string[]
|
||||
platform: string
|
||||
platform?: string
|
||||
platforms?: string[]
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -144,11 +145,36 @@ const showDropdown = ref(false)
|
||||
const searchQuery = ref('')
|
||||
const customModel = ref('')
|
||||
const isComposing = ref(false)
|
||||
const normalizedPlatforms = computed(() => {
|
||||
const rawPlatforms =
|
||||
props.platforms && props.platforms.length > 0
|
||||
? props.platforms
|
||||
: props.platform
|
||||
? [props.platform]
|
||||
: []
|
||||
|
||||
return Array.from(
|
||||
new Set(
|
||||
rawPlatforms
|
||||
.map(platform => platform?.trim())
|
||||
.filter((platform): platform is string => Boolean(platform))
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
const availableOptions = computed(() => {
|
||||
if (props.platform === 'sora') {
|
||||
return getModelsByPlatform('sora').map(m => ({ value: m, label: m }))
|
||||
if (normalizedPlatforms.value.length === 0) {
|
||||
return allModels
|
||||
}
|
||||
return allModels
|
||||
|
||||
const allowedModels = new Set<string>()
|
||||
for (const platform of normalizedPlatforms.value) {
|
||||
for (const model of getModelsByPlatform(platform)) {
|
||||
allowedModels.add(model)
|
||||
}
|
||||
}
|
||||
|
||||
return allModels.filter(model => allowedModels.has(model.value))
|
||||
})
|
||||
|
||||
const filteredModels = computed(() => {
|
||||
@@ -192,10 +218,13 @@ const handleEnter = () => {
|
||||
}
|
||||
|
||||
const fillRelated = () => {
|
||||
const models = getModelsByPlatform(props.platform)
|
||||
const newModels = [...props.modelValue]
|
||||
for (const model of models) {
|
||||
if (!newModels.includes(model)) newModels.push(model)
|
||||
for (const platform of normalizedPlatforms.value) {
|
||||
for (const model of getModelsByPlatform(platform)) {
|
||||
if (!newModels.includes(model)) {
|
||||
newModels.push(model)
|
||||
}
|
||||
}
|
||||
}
|
||||
emit('update:modelValue', newModels)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,29 @@
|
||||
<template>
|
||||
<div>
|
||||
<!-- Window stats row (above progress bar) -->
|
||||
<div
|
||||
v-if="windowStats"
|
||||
class="mb-0.5 flex items-center"
|
||||
>
|
||||
<div class="flex items-center gap-1.5 text-[9px] text-gray-500 dark:text-gray-400">
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
{{ formatRequests }} req
|
||||
</span>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
{{ formatTokens }}
|
||||
</span>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
A ${{ formatAccountCost }}
|
||||
</span>
|
||||
<span
|
||||
v-if="windowStats?.user_cost != null"
|
||||
class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800"
|
||||
>
|
||||
U ${{ formatUserCost }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Progress bar row -->
|
||||
<div class="flex items-center gap-1">
|
||||
<!-- Label badge (fixed width for alignment) -->
|
||||
@@ -108,4 +132,32 @@ const formatResetTime = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
// Window stats formatters
|
||||
const formatRequests = computed(() => {
|
||||
if (!props.windowStats) return ''
|
||||
const r = props.windowStats.requests
|
||||
if (r >= 1000000) return `${(r / 1000000).toFixed(1)}M`
|
||||
if (r >= 1000) return `${(r / 1000).toFixed(1)}K`
|
||||
return r.toString()
|
||||
})
|
||||
|
||||
const formatTokens = computed(() => {
|
||||
if (!props.windowStats) return ''
|
||||
const t = props.windowStats.tokens
|
||||
if (t >= 1000000000) return `${(t / 1000000000).toFixed(1)}B`
|
||||
if (t >= 1000000) return `${(t / 1000000).toFixed(1)}M`
|
||||
if (t >= 1000) return `${(t / 1000).toFixed(1)}K`
|
||||
return t.toString()
|
||||
})
|
||||
|
||||
const formatAccountCost = computed(() => {
|
||||
if (!props.windowStats) return '0.00'
|
||||
return props.windowStats.cost.toFixed(2)
|
||||
})
|
||||
|
||||
const formatUserCost = computed(() => {
|
||||
if (!props.windowStats || props.windowStats.user_cost == null) return '0.00'
|
||||
return props.windowStats.user_cost.toFixed(2)
|
||||
})
|
||||
|
||||
</script>
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { mount } from '@vue/test-utils'
|
||||
import AccountStatusIndicator from '../AccountStatusIndicator.vue'
|
||||
import type { Account } from '@/types'
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||
return {
|
||||
...actual,
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
function makeAccount(overrides: Partial<Account>): Account {
|
||||
return {
|
||||
id: 1,
|
||||
name: 'account',
|
||||
platform: 'antigravity',
|
||||
type: 'oauth',
|
||||
proxy_id: null,
|
||||
concurrency: 1,
|
||||
priority: 1,
|
||||
status: 'active',
|
||||
error_message: null,
|
||||
last_used_at: null,
|
||||
expires_at: null,
|
||||
auto_pause_on_expired: true,
|
||||
created_at: '2026-03-15T00:00:00Z',
|
||||
updated_at: '2026-03-15T00:00:00Z',
|
||||
schedulable: true,
|
||||
rate_limited_at: null,
|
||||
rate_limit_reset_at: null,
|
||||
overload_until: null,
|
||||
temp_unschedulable_until: null,
|
||||
temp_unschedulable_reason: null,
|
||||
session_window_start: null,
|
||||
session_window_end: null,
|
||||
session_window_status: null,
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
describe('AccountStatusIndicator', () => {
|
||||
it('模型限流 + overages 启用 + 无 AICredits key → 显示 ⚡ (credits_active)', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 1,
|
||||
name: 'ag-1',
|
||||
extra: {
|
||||
allow_overages: true,
|
||||
model_rate_limits: {
|
||||
'claude-sonnet-4-5': {
|
||||
rate_limited_at: '2026-03-15T00:00:00Z',
|
||||
rate_limit_reset_at: '2099-03-15T00:00:00Z'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('⚡')
|
||||
expect(wrapper.text()).toContain('CSon45')
|
||||
})
|
||||
|
||||
it('模型限流 + overages 未启用 → 普通限流样式(无 ⚡)', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 2,
|
||||
name: 'ag-2',
|
||||
extra: {
|
||||
model_rate_limits: {
|
||||
'claude-sonnet-4-5': {
|
||||
rate_limited_at: '2026-03-15T00:00:00Z',
|
||||
rate_limit_reset_at: '2099-03-15T00:00:00Z'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('CSon45')
|
||||
expect(wrapper.text()).not.toContain('⚡')
|
||||
})
|
||||
|
||||
it('AICredits key 生效 → 显示积分已用尽 (credits_exhausted)', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 3,
|
||||
name: 'ag-3',
|
||||
extra: {
|
||||
allow_overages: true,
|
||||
model_rate_limits: {
|
||||
'AICredits': {
|
||||
rate_limited_at: '2026-03-15T00:00:00Z',
|
||||
rate_limit_reset_at: '2099-03-15T00:00:00Z'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('account.creditsExhausted')
|
||||
})
|
||||
|
||||
it('模型限流 + overages 启用 + AICredits key 生效 → 普通限流样式(积分耗尽,无 ⚡)', () => {
|
||||
const wrapper = mount(AccountStatusIndicator, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 4,
|
||||
name: 'ag-4',
|
||||
extra: {
|
||||
allow_overages: true,
|
||||
model_rate_limits: {
|
||||
'claude-sonnet-4-5': {
|
||||
rate_limited_at: '2026-03-15T00:00:00Z',
|
||||
rate_limit_reset_at: '2099-03-15T00:00:00Z'
|
||||
},
|
||||
'AICredits': {
|
||||
rate_limited_at: '2026-03-15T00:00:00Z',
|
||||
rate_limit_reset_at: '2099-03-15T00:00:00Z'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Icon: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 模型限流 + 积分耗尽 → 不应显示 ⚡
|
||||
expect(wrapper.text()).toContain('CSon45')
|
||||
expect(wrapper.text()).not.toContain('⚡')
|
||||
// AICredits 积分耗尽状态应显示
|
||||
expect(wrapper.text()).toContain('account.creditsExhausted')
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,7 @@
|
||||
import { describe, expect, it, vi, beforeEach } from 'vitest'
|
||||
import { flushPromises, mount } from '@vue/test-utils'
|
||||
import AccountUsageCell from '../AccountUsageCell.vue'
|
||||
import type { Account } from '@/types'
|
||||
|
||||
const { getUsage } = vi.hoisted(() => ({
|
||||
getUsage: vi.fn()
|
||||
@@ -24,6 +25,35 @@ vi.mock('vue-i18n', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
function makeAccount(overrides: Partial<Account>): Account {
|
||||
return {
|
||||
id: 1,
|
||||
name: 'account',
|
||||
platform: 'antigravity',
|
||||
type: 'oauth',
|
||||
proxy_id: null,
|
||||
concurrency: 1,
|
||||
priority: 1,
|
||||
status: 'active',
|
||||
error_message: null,
|
||||
last_used_at: null,
|
||||
expires_at: null,
|
||||
auto_pause_on_expired: true,
|
||||
created_at: '2026-03-15T00:00:00Z',
|
||||
updated_at: '2026-03-15T00:00:00Z',
|
||||
schedulable: true,
|
||||
rate_limited_at: null,
|
||||
rate_limit_reset_at: null,
|
||||
overload_until: null,
|
||||
temp_unschedulable_until: null,
|
||||
temp_unschedulable_reason: null,
|
||||
session_window_start: null,
|
||||
session_window_end: null,
|
||||
session_window_status: null,
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
describe('AccountUsageCell', () => {
|
||||
beforeEach(() => {
|
||||
getUsage.mockReset()
|
||||
@@ -49,12 +79,12 @@ describe('AccountUsageCell', () => {
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
account: makeAccount({
|
||||
id: 1001,
|
||||
platform: 'antigravity',
|
||||
type: 'oauth',
|
||||
extra: {}
|
||||
} as any
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
@@ -72,6 +102,40 @@ describe('AccountUsageCell', () => {
|
||||
expect(wrapper.text()).toContain('admin.accounts.usageWindow.gemini3Image|70|2026-03-01T09:00:00Z')
|
||||
})
|
||||
|
||||
it('Antigravity 会显示 AI Credits 余额信息', async () => {
|
||||
getUsage.mockResolvedValue({
|
||||
ai_credits: [
|
||||
{
|
||||
credit_type: 'GOOGLE_ONE_AI',
|
||||
amount: 25,
|
||||
minimum_balance: 5
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 1002,
|
||||
platform: 'antigravity',
|
||||
type: 'oauth',
|
||||
extra: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(wrapper.text()).toContain('admin.accounts.aiCreditsBalance')
|
||||
expect(wrapper.text()).toContain('25')
|
||||
})
|
||||
|
||||
|
||||
it('OpenAI OAuth 快照已过期时首屏会重新请求 usage', async () => {
|
||||
getUsage.mockResolvedValue({
|
||||
@@ -103,7 +167,7 @@ describe('AccountUsageCell', () => {
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
account: makeAccount({
|
||||
id: 2000,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
@@ -114,7 +178,7 @@ describe('AccountUsageCell', () => {
|
||||
codex_7d_used_percent: 34,
|
||||
codex_7d_reset_at: '2026-03-13T12:00:00Z'
|
||||
}
|
||||
} as any
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
@@ -137,7 +201,7 @@ describe('AccountUsageCell', () => {
|
||||
it('OpenAI OAuth 有现成快照且未限额时不会首屏请求 usage', async () => {
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
account: makeAccount({
|
||||
id: 2001,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
@@ -148,7 +212,7 @@ describe('AccountUsageCell', () => {
|
||||
codex_7d_used_percent: 34,
|
||||
codex_7d_reset_at: '2099-03-13T12:00:00Z'
|
||||
}
|
||||
} as any
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
@@ -196,15 +260,15 @@ describe('AccountUsageCell', () => {
|
||||
}
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
id: 2002,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
extra: {}
|
||||
} as any
|
||||
},
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 2002,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
extra: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: {
|
||||
@@ -256,16 +320,16 @@ describe('AccountUsageCell', () => {
|
||||
seven_day: null
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
id: 2003,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
updated_at: '2026-03-07T10:00:00Z',
|
||||
extra: {}
|
||||
} as any
|
||||
},
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 2003,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
updated_at: '2026-03-07T10:00:00Z',
|
||||
extra: {}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: {
|
||||
@@ -324,19 +388,19 @@ describe('AccountUsageCell', () => {
|
||||
}
|
||||
})
|
||||
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: {
|
||||
id: 2004,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
rate_limit_reset_at: '2099-03-07T12:00:00Z',
|
||||
extra: {
|
||||
codex_5h_used_percent: 0,
|
||||
codex_7d_used_percent: 0
|
||||
}
|
||||
} as any
|
||||
},
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 2004,
|
||||
platform: 'openai',
|
||||
type: 'oauth',
|
||||
rate_limit_reset_at: '2099-03-07T12:00:00Z',
|
||||
extra: {
|
||||
codex_5h_used_percent: 0,
|
||||
codex_7d_used_percent: 0
|
||||
}
|
||||
})
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: {
|
||||
|
||||
@@ -410,6 +410,18 @@
|
||||
|
||||
<!-- Model Distribution -->
|
||||
<ModelDistributionChart :model-stats="stats.models" :loading="false" />
|
||||
|
||||
<EndpointDistributionChart
|
||||
:endpoint-stats="stats.endpoints || []"
|
||||
:loading="false"
|
||||
:title="t('usage.inboundEndpoint')"
|
||||
/>
|
||||
|
||||
<EndpointDistributionChart
|
||||
:endpoint-stats="stats.upstream_endpoints || []"
|
||||
:loading="false"
|
||||
:title="t('usage.upstreamEndpoint')"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<!-- No Data State -->
|
||||
@@ -453,6 +465,7 @@ import { Line } from 'vue-chartjs'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
|
||||
import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, AccountUsageStatsResponse } from '@/types'
|
||||
|
||||
@@ -24,7 +24,7 @@ const updateType = (value: string | number | boolean | null) => { emit('update:f
|
||||
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
|
||||
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
|
||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
|
||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
|
||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
|
||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
|
||||
const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))])
|
||||
</script>
|
||||
|
||||
@@ -35,6 +35,19 @@
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template #cell-endpoint="{ row }">
|
||||
<div class="max-w-[320px] space-y-1 text-xs">
|
||||
<div class="break-all text-gray-700 dark:text-gray-300">
|
||||
<span class="font-medium text-gray-500 dark:text-gray-400">{{ t('usage.inbound') }}:</span>
|
||||
<span class="ml-1">{{ row.inbound_endpoint?.trim() || '-' }}</span>
|
||||
</div>
|
||||
<div class="break-all text-gray-700 dark:text-gray-300">
|
||||
<span class="font-medium text-gray-500 dark:text-gray-400">{{ t('usage.upstream') }}:</span>
|
||||
<span class="ml-1">{{ row.upstream_endpoint?.trim() || '-' }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<template #cell-group="{ row }">
|
||||
<span v-if="row.group" class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium bg-indigo-100 text-indigo-800 dark:bg-indigo-900 dark:text-indigo-200">
|
||||
{{ row.group.name }}
|
||||
@@ -328,6 +341,7 @@ const getRequestTypeBadgeClass = (row: AdminUsageLog): string => {
|
||||
if (requestType === 'sync') return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200'
|
||||
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
|
||||
}
|
||||
|
||||
const formatCacheTokens = (tokens: number): string => {
|
||||
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
|
||||
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
|
||||
|
||||
257
frontend/src/components/charts/EndpointDistributionChart.vue
Normal file
257
frontend/src/components/charts/EndpointDistributionChart.vue
Normal file
@@ -0,0 +1,257 @@
|
||||
<template>
|
||||
<div class="card p-4">
|
||||
<div class="mb-4 flex items-start justify-between gap-3">
|
||||
<h3 class="text-sm font-semibold text-gray-900 dark:text-white">
|
||||
{{ title || t('usage.endpointDistribution') }}
|
||||
</h3>
|
||||
<div class="flex flex-col items-end gap-2">
|
||||
<div
|
||||
v-if="showSourceToggle"
|
||||
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="source === 'inbound'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:source', 'inbound')"
|
||||
>
|
||||
{{ t('usage.inbound') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="source === 'upstream'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:source', 'upstream')"
|
||||
>
|
||||
{{ t('usage.upstream') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="source === 'path'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:source', 'path')"
|
||||
>
|
||||
{{ t('usage.path') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="showMetricToggle"
|
||||
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="metric === 'tokens'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:metric', 'tokens')"
|
||||
>
|
||||
{{ t('admin.dashboard.metricTokens') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
|
||||
:class="metric === 'actual_cost'
|
||||
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
|
||||
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
|
||||
@click="emit('update:metric', 'actual_cost')"
|
||||
>
|
||||
{{ t('admin.dashboard.metricActualCost') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="loading" class="flex h-48 items-center justify-center">
|
||||
<LoadingSpinner />
|
||||
</div>
|
||||
<div v-else-if="displayEndpointStats.length > 0 && chartData" class="flex items-center gap-6">
|
||||
<div class="h-48 w-48">
|
||||
<Doughnut :data="chartData" :options="doughnutOptions" />
|
||||
</div>
|
||||
<div class="max-h-48 flex-1 overflow-y-auto">
|
||||
<table class="w-full text-xs">
|
||||
<thead>
|
||||
<tr class="text-gray-500 dark:text-gray-400">
|
||||
<th class="pb-2 text-left">{{ t('usage.endpoint') }}</th>
|
||||
<th class="pb-2 text-right">{{ t('admin.dashboard.requests') }}</th>
|
||||
<th class="pb-2 text-right">{{ t('admin.dashboard.tokens') }}</th>
|
||||
<th class="pb-2 text-right">{{ t('admin.dashboard.actual') }}</th>
|
||||
<th class="pb-2 text-right">{{ t('admin.dashboard.standard') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="item in displayEndpointStats"
|
||||
:key="item.endpoint"
|
||||
class="border-t border-gray-100 dark:border-gray-700"
|
||||
>
|
||||
<td class="max-w-[180px] truncate py-1.5 font-medium text-gray-900 dark:text-white" :title="item.endpoint">
|
||||
{{ item.endpoint }}
|
||||
</td>
|
||||
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">
|
||||
{{ formatNumber(item.requests) }}
|
||||
</td>
|
||||
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">
|
||||
{{ formatTokens(item.total_tokens) }}
|
||||
</td>
|
||||
<td class="py-1.5 text-right text-green-600 dark:text-green-400">
|
||||
${{ formatCost(item.actual_cost) }}
|
||||
</td>
|
||||
<td class="py-1.5 text-right text-gray-400 dark:text-gray-500">
|
||||
${{ formatCost(item.cost) }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else class="flex h-48 items-center justify-center text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.dashboard.noDataAvailable') }}
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { Chart as ChartJS, ArcElement, Tooltip, Legend } from 'chart.js'
|
||||
import { Doughnut } from 'vue-chartjs'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import type { EndpointStat } from '@/types'
|
||||
|
||||
ChartJS.register(ArcElement, Tooltip, Legend)
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
type DistributionMetric = 'tokens' | 'actual_cost'
|
||||
type EndpointSource = 'inbound' | 'upstream' | 'path'
|
||||
|
||||
const props = withDefaults(
|
||||
defineProps<{
|
||||
endpointStats: EndpointStat[]
|
||||
upstreamEndpointStats?: EndpointStat[]
|
||||
endpointPathStats?: EndpointStat[]
|
||||
loading?: boolean
|
||||
title?: string
|
||||
metric?: DistributionMetric
|
||||
source?: EndpointSource
|
||||
showMetricToggle?: boolean
|
||||
showSourceToggle?: boolean
|
||||
}>(),
|
||||
{
|
||||
upstreamEndpointStats: () => [],
|
||||
endpointPathStats: () => [],
|
||||
loading: false,
|
||||
title: '',
|
||||
metric: 'tokens',
|
||||
source: 'inbound',
|
||||
showMetricToggle: false,
|
||||
showSourceToggle: false
|
||||
}
|
||||
)
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:metric': [value: DistributionMetric]
|
||||
'update:source': [value: EndpointSource]
|
||||
}>()
|
||||
|
||||
const chartColors = [
|
||||
'#3b82f6',
|
||||
'#10b981',
|
||||
'#f59e0b',
|
||||
'#ef4444',
|
||||
'#8b5cf6',
|
||||
'#ec4899',
|
||||
'#14b8a6',
|
||||
'#f97316',
|
||||
'#6366f1',
|
||||
'#84cc16',
|
||||
'#06b6d4',
|
||||
'#a855f7'
|
||||
]
|
||||
|
||||
const displayEndpointStats = computed(() => {
|
||||
const sourceStats = props.source === 'upstream'
|
||||
? props.upstreamEndpointStats
|
||||
: props.source === 'path'
|
||||
? props.endpointPathStats
|
||||
: props.endpointStats
|
||||
if (!sourceStats?.length) return []
|
||||
|
||||
const metricKey = props.metric === 'actual_cost' ? 'actual_cost' : 'total_tokens'
|
||||
return [...sourceStats].sort((a, b) => b[metricKey] - a[metricKey])
|
||||
})
|
||||
|
||||
const chartData = computed(() => {
|
||||
if (!displayEndpointStats.value?.length) return null
|
||||
|
||||
return {
|
||||
labels: displayEndpointStats.value.map((item) => item.endpoint),
|
||||
datasets: [
|
||||
{
|
||||
data: displayEndpointStats.value.map((item) =>
|
||||
props.metric === 'actual_cost' ? item.actual_cost : item.total_tokens
|
||||
),
|
||||
backgroundColor: chartColors.slice(0, displayEndpointStats.value.length),
|
||||
borderWidth: 0
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
const doughnutOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
legend: {
|
||||
display: false
|
||||
},
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
label: (context: any) => {
|
||||
const value = context.raw as number
|
||||
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
|
||||
const percentage = total > 0 ? ((value / total) * 100).toFixed(1) : '0.0'
|
||||
const formattedValue = props.metric === 'actual_cost'
|
||||
? `$${formatCost(value)}`
|
||||
: formatTokens(value)
|
||||
return `${context.label}: ${formattedValue} (${percentage}%)`
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
const formatTokens = (value: number): string => {
|
||||
if (value >= 1_000_000_000) {
|
||||
return `${(value / 1_000_000_000).toFixed(2)}B`
|
||||
} else if (value >= 1_000_000) {
|
||||
return `${(value / 1_000_000).toFixed(2)}M`
|
||||
} else if (value >= 1_000) {
|
||||
return `${(value / 1_000).toFixed(2)}K`
|
||||
}
|
||||
return value.toLocaleString()
|
||||
}
|
||||
|
||||
const formatNumber = (value: number): string => {
|
||||
return value.toLocaleString()
|
||||
}
|
||||
|
||||
const formatCost = (value: number): string => {
|
||||
if (value >= 1000) {
|
||||
return (value / 1000).toFixed(2) + 'K'
|
||||
} else if (value >= 1) {
|
||||
return value.toFixed(2)
|
||||
} else if (value >= 0.01) {
|
||||
return value.toFixed(3)
|
||||
}
|
||||
return value.toFixed(4)
|
||||
}
|
||||
</script>
|
||||
@@ -127,7 +127,7 @@
|
||||
>
|
||||
{{ t('admin.dashboard.failedToLoad') }}
|
||||
</div>
|
||||
<div v-else-if="rankingItems.length > 0 && rankingChartData" class="flex items-center gap-6">
|
||||
<div v-else-if="rankingDisplayItems.length > 0 && rankingChartData" class="flex items-center gap-6">
|
||||
<div class="h-48 w-48">
|
||||
<Doughnut :data="rankingChartData" :options="rankingDoughnutOptions" />
|
||||
</div>
|
||||
@@ -143,21 +143,24 @@
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="(item, index) in rankingItems"
|
||||
:key="`${item.user_id}-${index}`"
|
||||
class="cursor-pointer border-t border-gray-100 transition-colors hover:bg-gray-50 dark:border-gray-700 dark:hover:bg-dark-700/40"
|
||||
@click="emit('ranking-click', item)"
|
||||
v-for="(item, index) in rankingDisplayItems"
|
||||
:key="item.isOther ? 'others' : `${item.user_id}-${index}`"
|
||||
class="border-t border-gray-100 transition-colors dark:border-gray-700"
|
||||
:class="item.isOther
|
||||
? 'bg-gray-50/70 dark:bg-dark-700/20'
|
||||
: 'cursor-pointer hover:bg-gray-50 dark:hover:bg-dark-700/40'"
|
||||
@click="item.isOther ? undefined : emit('ranking-click', item)"
|
||||
>
|
||||
<td class="py-1.5">
|
||||
<div class="flex min-w-0 items-center gap-2">
|
||||
<span class="shrink-0 text-[11px] font-semibold text-gray-500 dark:text-gray-400">
|
||||
#{{ index + 1 }}
|
||||
{{ item.isOther ? 'Σ' : `#${index + 1}` }}
|
||||
</span>
|
||||
<span
|
||||
class="block max-w-[140px] truncate font-medium text-gray-900 dark:text-white"
|
||||
:title="getRankingUserLabel(item)"
|
||||
:title="getRankingRowLabel(item)"
|
||||
>
|
||||
{{ getRankingUserLabel(item) }}
|
||||
{{ getRankingRowLabel(item) }}
|
||||
</span>
|
||||
</div>
|
||||
</td>
|
||||
@@ -197,11 +200,14 @@ ChartJS.register(ArcElement, Tooltip, Legend)
|
||||
const { t } = useI18n()
|
||||
|
||||
type DistributionMetric = 'tokens' | 'actual_cost'
|
||||
type RankingDisplayItem = UserSpendingRankingItem & { isOther?: boolean }
|
||||
const props = withDefaults(defineProps<{
|
||||
modelStats: ModelStat[]
|
||||
enableRankingView?: boolean
|
||||
rankingItems?: UserSpendingRankingItem[]
|
||||
rankingTotalActualCost?: number
|
||||
rankingTotalRequests?: number
|
||||
rankingTotalTokens?: number
|
||||
loading?: boolean
|
||||
metric?: DistributionMetric
|
||||
showMetricToggle?: boolean
|
||||
@@ -211,6 +217,8 @@ const props = withDefaults(defineProps<{
|
||||
enableRankingView: false,
|
||||
rankingItems: () => [],
|
||||
rankingTotalActualCost: 0,
|
||||
rankingTotalRequests: 0,
|
||||
rankingTotalTokens: 0,
|
||||
loading: false,
|
||||
metric: 'tokens',
|
||||
showMetricToggle: false,
|
||||
@@ -266,14 +274,14 @@ const chartData = computed(() => {
|
||||
const rankingChartData = computed(() => {
|
||||
if (!props.rankingItems?.length) return null
|
||||
|
||||
const rankedTotal = props.rankingItems.reduce((sum, item) => sum + item.actual_cost, 0)
|
||||
const otherActualCost = Math.max((props.rankingTotalActualCost || 0) - rankedTotal, 0)
|
||||
const labels = props.rankingItems.map((item, index) => `#${index + 1} ${getRankingUserLabel(item)}`)
|
||||
const data = props.rankingItems.map((item) => item.actual_cost)
|
||||
const backgroundColor = chartColors.slice(0, props.rankingItems.length)
|
||||
|
||||
if (otherActualCost > 0.000001) {
|
||||
if (otherRankingItem.value) {
|
||||
labels.push(t('admin.dashboard.spendingRankingOther'))
|
||||
data.push(otherActualCost)
|
||||
data.push(otherRankingItem.value.actual_cost)
|
||||
backgroundColor.push('#94a3b8')
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -281,13 +289,43 @@ const rankingChartData = computed(() => {
|
||||
datasets: [
|
||||
{
|
||||
data,
|
||||
backgroundColor: chartColors.slice(0, data.length),
|
||||
backgroundColor,
|
||||
borderWidth: 0
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
const otherRankingItem = computed<RankingDisplayItem | null>(() => {
|
||||
if (!props.rankingItems?.length) return null
|
||||
|
||||
const rankedActualCost = props.rankingItems.reduce((sum, item) => sum + item.actual_cost, 0)
|
||||
const rankedRequests = props.rankingItems.reduce((sum, item) => sum + item.requests, 0)
|
||||
const rankedTokens = props.rankingItems.reduce((sum, item) => sum + item.tokens, 0)
|
||||
|
||||
const otherActualCost = Math.max((props.rankingTotalActualCost || 0) - rankedActualCost, 0)
|
||||
const otherRequests = Math.max((props.rankingTotalRequests || 0) - rankedRequests, 0)
|
||||
const otherTokens = Math.max((props.rankingTotalTokens || 0) - rankedTokens, 0)
|
||||
|
||||
if (otherActualCost <= 0.000001 && otherRequests <= 0 && otherTokens <= 0) return null
|
||||
|
||||
return {
|
||||
user_id: 0,
|
||||
email: '',
|
||||
actual_cost: otherActualCost,
|
||||
requests: otherRequests,
|
||||
tokens: otherTokens,
|
||||
isOther: true
|
||||
}
|
||||
})
|
||||
|
||||
const rankingDisplayItems = computed<RankingDisplayItem[]>(() => {
|
||||
if (!props.rankingItems?.length) return []
|
||||
return otherRankingItem.value
|
||||
? [...props.rankingItems, otherRankingItem.value]
|
||||
: [...props.rankingItems]
|
||||
})
|
||||
|
||||
const doughnutOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
@@ -351,6 +389,11 @@ const getRankingUserLabel = (item: UserSpendingRankingItem): string => {
|
||||
return t('admin.redeem.userPrefix', { id: item.user_id })
|
||||
}
|
||||
|
||||
const getRankingRowLabel = (item: RankingDisplayItem): string => {
|
||||
if (item.isOther) return t('admin.dashboard.spendingRankingOther')
|
||||
return getRankingUserLabel(item)
|
||||
}
|
||||
|
||||
const formatCost = (value: number): string => {
|
||||
if (value >= 1000) {
|
||||
return (value / 1000).toFixed(2) + 'K'
|
||||
|
||||
@@ -5,6 +5,14 @@ import ModelDistributionChart from '../ModelDistributionChart.vue'
|
||||
|
||||
const messages: Record<string, string> = {
|
||||
'admin.dashboard.modelDistribution': 'Model Distribution',
|
||||
'admin.dashboard.spendingRankingTitle': 'User Spending Ranking',
|
||||
'admin.dashboard.viewModelDistribution': 'Model Distribution',
|
||||
'admin.dashboard.viewSpendingRanking': 'User Spending Ranking',
|
||||
'admin.dashboard.spendingRankingUser': 'User',
|
||||
'admin.dashboard.spendingRankingRequests': 'Requests',
|
||||
'admin.dashboard.spendingRankingTokens': 'Tokens',
|
||||
'admin.dashboard.spendingRankingSpend': 'Spend',
|
||||
'admin.dashboard.spendingRankingOther': 'Others',
|
||||
'admin.dashboard.model': 'Model',
|
||||
'admin.dashboard.requests': 'Requests',
|
||||
'admin.dashboard.tokens': 'Tokens',
|
||||
@@ -13,6 +21,7 @@ const messages: Record<string, string> = {
|
||||
'admin.dashboard.metricTokens': 'By Tokens',
|
||||
'admin.dashboard.metricActualCost': 'By Actual Cost',
|
||||
'admin.dashboard.noDataAvailable': 'No data available',
|
||||
'admin.redeem.userPrefix': 'User #{id}',
|
||||
}
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
@@ -116,4 +125,47 @@ describe('ModelDistributionChart', () => {
|
||||
})
|
||||
expect(label).toBe('model-b: $1.40 (87.5%)')
|
||||
})
|
||||
|
||||
it('renders Others in the spending ranking table and uses a dedicated chart color', async () => {
|
||||
const wrapper = mount(ModelDistributionChart, {
|
||||
props: {
|
||||
modelStats: [],
|
||||
enableRankingView: true,
|
||||
rankingItems: [
|
||||
{ user_id: 1, email: 'alpha@example.com', actual_cost: 12, requests: 10, tokens: 1000 },
|
||||
{ user_id: 2, email: 'beta@example.com', actual_cost: 8, requests: 6, tokens: 600 },
|
||||
],
|
||||
rankingTotalActualCost: 30,
|
||||
rankingTotalRequests: 20,
|
||||
rankingTotalTokens: 2000,
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
LoadingSpinner: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const rankingButton = wrapper.findAll('button').find((button) => button.text() === 'User Spending Ranking')
|
||||
expect(rankingButton).toBeTruthy()
|
||||
await rankingButton!.trigger('click')
|
||||
|
||||
const chartData = JSON.parse(wrapper.find('.chart-data').text())
|
||||
expect(chartData.labels).toEqual([
|
||||
'#1 alpha@example.com',
|
||||
'#2 beta@example.com',
|
||||
'Others',
|
||||
])
|
||||
expect(chartData.datasets[0].data).toEqual([12, 8, 10])
|
||||
expect(chartData.datasets[0].backgroundColor[0]).toBe('#3b82f6')
|
||||
expect(chartData.datasets[0].backgroundColor[2]).toBe('#94a3b8')
|
||||
expect(chartData.datasets[0].backgroundColor[2]).not.toBe(chartData.datasets[0].backgroundColor[0])
|
||||
|
||||
const rows = wrapper.findAll('tbody tr')
|
||||
expect(rows).toHaveLength(3)
|
||||
expect(rows[2].text()).toContain('Others')
|
||||
expect(rows[2].text()).toContain('4')
|
||||
expect(rows[2].text()).toContain('400')
|
||||
expect(rows[2].text()).toContain('$10.00')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -245,6 +245,7 @@ export default {
|
||||
// Common
|
||||
common: {
|
||||
loading: 'Loading...',
|
||||
justNow: 'just now',
|
||||
save: 'Save',
|
||||
cancel: 'Cancel',
|
||||
delete: 'Delete',
|
||||
@@ -718,6 +719,13 @@ export default {
|
||||
preparingExport: 'Preparing export...',
|
||||
model: 'Model',
|
||||
reasoningEffort: 'Reasoning Effort',
|
||||
endpoint: 'Endpoint',
|
||||
endpointDistribution: 'Endpoint Distribution',
|
||||
inbound: 'Inbound',
|
||||
upstream: 'Upstream',
|
||||
path: 'Path',
|
||||
inboundEndpoint: 'Inbound Endpoint',
|
||||
upstreamEndpoint: 'Upstream Endpoint',
|
||||
type: 'Type',
|
||||
tokens: 'Tokens',
|
||||
cost: 'Cost',
|
||||
@@ -1648,6 +1656,14 @@ export default {
|
||||
enabled: 'Enabled',
|
||||
disabled: 'Disabled'
|
||||
},
|
||||
claudeMaxSimulation: {
|
||||
title: 'Claude Max Usage Simulation',
|
||||
tooltip:
|
||||
'When enabled, for Claude models without upstream cache-write usage, the system deterministically maps tokens to a small input plus 1h cache creation while keeping total tokens unchanged.',
|
||||
enabled: 'Enabled (simulate 1h cache)',
|
||||
disabled: 'Disabled',
|
||||
hint: 'Only token categories in usage billing logs are adjusted. No per-request mapping state is persisted.'
|
||||
},
|
||||
supportedScopes: {
|
||||
title: 'Supported Model Families',
|
||||
tooltip: 'Select the model families this group supports. Unchecked families will not be routed to this group.',
|
||||
@@ -1860,6 +1876,9 @@ export default {
|
||||
rateLimitedUntil: 'Rate limited and removed from scheduling. Auto resumes at {time}',
|
||||
rateLimitedAutoResume: 'Auto resumes in {time}',
|
||||
modelRateLimitedUntil: '{model} rate limited until {time}',
|
||||
modelCreditOveragesUntil: '{model} using AI Credits until {time}',
|
||||
creditsExhausted: 'Credits Exhausted',
|
||||
creditsExhaustedUntil: 'AI Credits exhausted, expected recovery at {time}',
|
||||
overloadedUntil: 'Overloaded until {time}',
|
||||
viewTempUnschedDetails: 'View temp unschedulable details'
|
||||
},
|
||||
@@ -1961,7 +1980,7 @@ export default {
|
||||
resetQuota: 'Reset Quota',
|
||||
quotaLimit: 'Quota Limit',
|
||||
quotaLimitPlaceholder: '0 means unlimited',
|
||||
quotaLimitHint: 'Set daily/weekly/total spending limits (USD). Account will be paused when any limit is reached. Changing limits won\'t reset usage.',
|
||||
quotaLimitHint: 'Set daily/weekly/total spending limits (USD). Anthropic API key accounts can also configure client affinity. Changing limits won\'t reset usage.',
|
||||
quotaLimitToggle: 'Enable Quota Limit',
|
||||
quotaLimitToggleHint: 'When enabled, account will be paused when usage reaches the set limit',
|
||||
quotaDailyLimit: 'Daily Limit',
|
||||
@@ -2158,7 +2177,7 @@ export default {
|
||||
// Quota control (Anthropic OAuth/SetupToken only)
|
||||
quotaControl: {
|
||||
title: 'Quota Control',
|
||||
hint: 'Only applies to Anthropic OAuth/Setup Token accounts',
|
||||
hint: 'Configure cost window, session limits, client affinity and other scheduling controls.',
|
||||
windowCost: {
|
||||
label: '5h Window Cost Limit',
|
||||
hint: 'Limit account cost usage within the 5-hour window',
|
||||
@@ -2213,8 +2232,26 @@ export default {
|
||||
hint: 'Force all cache creation tokens to be billed as the selected TTL tier (5m or 1h)',
|
||||
target: 'Target TTL',
|
||||
targetHint: 'Select the TTL tier for billing'
|
||||
},
|
||||
clientAffinity: {
|
||||
label: 'Client Affinity Scheduling',
|
||||
hint: 'When enabled, new sessions prefer accounts previously used by this client to reduce account switching'
|
||||
}
|
||||
},
|
||||
affinityNoClients: 'No affinity clients',
|
||||
affinityClients: '{count} affinity clients:',
|
||||
affinitySection: 'Client Affinity',
|
||||
affinitySectionHint: 'Control how clients are distributed across accounts. Configure zone thresholds to balance load.',
|
||||
affinityToggle: 'Enable Client Affinity',
|
||||
affinityToggleHint: 'New sessions prefer accounts previously used by this client',
|
||||
affinityBase: 'Base Limit (Green Zone)',
|
||||
affinityBasePlaceholder: 'Empty = no limit',
|
||||
affinityBaseHint: 'Max clients in green zone (full priority scheduling)',
|
||||
affinityBaseOffHint: 'No green zone limit. All clients receive full priority scheduling.',
|
||||
affinityBuffer: 'Buffer (Yellow Zone)',
|
||||
affinityBufferPlaceholder: 'e.g. 3',
|
||||
affinityBufferHint: 'Additional clients allowed in the yellow zone (degraded priority)',
|
||||
affinityBufferInfinite: 'Unlimited',
|
||||
expired: 'Expired',
|
||||
proxy: 'Proxy',
|
||||
noProxy: 'No Proxy',
|
||||
@@ -2232,6 +2269,10 @@ export default {
|
||||
mixedSchedulingHint: 'Enable to participate in Anthropic/Gemini group scheduling',
|
||||
mixedSchedulingTooltip:
|
||||
'!! WARNING !! Antigravity Claude and Anthropic Claude cannot be used in the same context. If you have both Anthropic and Antigravity accounts, enabling this option will cause frequent 400 errors. When enabled, please use the group feature to isolate Antigravity accounts from Anthropic accounts. Make sure you understand this before enabling!!',
|
||||
aiCreditsBalance: 'AI Credits',
|
||||
allowOverages: 'Allow Overages (AI Credits)',
|
||||
allowOveragesTooltip:
|
||||
'Only use AI Credits after free quota is explicitly exhausted. Ordinary concurrent 429 rate limits will not switch to overages.',
|
||||
creating: 'Creating...',
|
||||
updating: 'Updating...',
|
||||
accountCreated: 'Account created successfully',
|
||||
@@ -2665,7 +2706,7 @@ export default {
|
||||
geminiFlashDaily: 'Flash',
|
||||
gemini3Pro: 'G3P',
|
||||
gemini3Flash: 'G3F',
|
||||
gemini3Image: 'GImage',
|
||||
gemini3Image: 'G31FI',
|
||||
claude: 'Claude'
|
||||
},
|
||||
tier: {
|
||||
@@ -3835,6 +3876,8 @@ export default {
|
||||
ignoreNoAvailableAccountsHint: 'When enabled, "No available accounts" errors will not be written to the error log (not recommended; usually a config issue).',
|
||||
ignoreInvalidApiKeyErrors: 'Ignore invalid API key errors',
|
||||
ignoreInvalidApiKeyErrorsHint: 'When enabled, invalid or missing API key errors (INVALID_API_KEY, API_KEY_REQUIRED) will not be written to the error log.',
|
||||
ignoreInsufficientBalanceErrors: 'Ignore Insufficient Balance Errors',
|
||||
ignoreInsufficientBalanceErrorsHint: 'When enabled, insufficient account balance errors will not be written to the error log.',
|
||||
autoRefresh: 'Auto Refresh',
|
||||
enableAutoRefresh: 'Enable auto refresh',
|
||||
enableAutoRefreshHint: 'Automatically refresh dashboard data at a fixed interval.',
|
||||
@@ -3959,6 +4002,9 @@ export default {
|
||||
invitationCodeHint: 'When enabled, users must enter a valid invitation code to register',
|
||||
passwordReset: 'Password Reset',
|
||||
passwordResetHint: 'Allow users to reset their password via email',
|
||||
frontendUrl: 'Frontend URL',
|
||||
frontendUrlPlaceholder: 'https://example.com',
|
||||
frontendUrlHint: 'Used to generate password reset links in emails. Example: https://example.com',
|
||||
totp: 'Two-Factor Authentication (2FA)',
|
||||
totpHint: 'Allow users to use authenticator apps like Google Authenticator',
|
||||
totpKeyNotConfigured:
|
||||
@@ -4173,40 +4219,55 @@ export default {
|
||||
usage: 'Usage: Add to request header - x-api-key: <your-admin-api-key>'
|
||||
},
|
||||
soraS3: {
|
||||
title: 'Sora S3 Storage',
|
||||
description: 'Manage multiple Sora S3 endpoints and switch the active profile',
|
||||
title: 'Sora Storage',
|
||||
description: 'Manage Sora media storage profiles with S3 and Google Drive support',
|
||||
newProfile: 'New Profile',
|
||||
reloadProfiles: 'Reload Profiles',
|
||||
empty: 'No Sora S3 profiles yet, create one first',
|
||||
createTitle: 'Create Sora S3 Profile',
|
||||
editTitle: 'Edit Sora S3 Profile',
|
||||
empty: 'No storage profiles yet, create one first',
|
||||
createTitle: 'Create Storage Profile',
|
||||
editTitle: 'Edit Storage Profile',
|
||||
selectProvider: 'Select Storage Type',
|
||||
providerS3Desc: 'S3-compatible object storage',
|
||||
providerGDriveDesc: 'Google Drive cloud storage',
|
||||
profileID: 'Profile ID',
|
||||
profileName: 'Profile Name',
|
||||
setActive: 'Set as active after creation',
|
||||
saveProfile: 'Save Profile',
|
||||
activateProfile: 'Activate',
|
||||
profileCreated: 'Sora S3 profile created',
|
||||
profileSaved: 'Sora S3 profile saved',
|
||||
profileDeleted: 'Sora S3 profile deleted',
|
||||
profileActivated: 'Sora S3 active profile switched',
|
||||
profileCreated: 'Storage profile created',
|
||||
profileSaved: 'Storage profile saved',
|
||||
profileDeleted: 'Storage profile deleted',
|
||||
profileActivated: 'Active storage profile switched',
|
||||
profileIDRequired: 'Profile ID is required',
|
||||
profileNameRequired: 'Profile name is required',
|
||||
profileSelectRequired: 'Please select a profile first',
|
||||
endpointRequired: 'S3 endpoint is required when enabled',
|
||||
bucketRequired: 'Bucket is required when enabled',
|
||||
accessKeyRequired: 'Access Key ID is required when enabled',
|
||||
deleteConfirm: 'Delete Sora S3 profile {profileID}?',
|
||||
deleteConfirm: 'Delete storage profile {profileID}?',
|
||||
columns: {
|
||||
profile: 'Profile',
|
||||
profileId: 'Profile ID',
|
||||
name: 'Name',
|
||||
provider: 'Type',
|
||||
active: 'Active',
|
||||
endpoint: 'Endpoint',
|
||||
bucket: 'Bucket',
|
||||
storagePath: 'Storage Path',
|
||||
capacityUsage: 'Capacity / Used',
|
||||
capacityUnlimited: 'Unlimited',
|
||||
videoCount: 'Videos',
|
||||
videoCompleted: 'completed',
|
||||
videoInProgress: 'in progress',
|
||||
quota: 'Default Quota',
|
||||
updatedAt: 'Updated At',
|
||||
actions: 'Actions'
|
||||
actions: 'Actions',
|
||||
rootFolder: 'Root folder',
|
||||
testInTable: 'Test',
|
||||
testingInTable: 'Testing...',
|
||||
testTimeout: 'Test timed out (15s)'
|
||||
},
|
||||
enabled: 'Enable S3 Storage',
|
||||
enabledHint: 'When enabled, Sora generated media files will be automatically uploaded to S3 storage',
|
||||
enabled: 'Enable Storage',
|
||||
enabledHint: 'When enabled, Sora generated media files will be automatically uploaded',
|
||||
endpoint: 'S3 Endpoint',
|
||||
region: 'Region',
|
||||
bucket: 'Bucket',
|
||||
@@ -4215,16 +4276,38 @@ export default {
|
||||
secretAccessKey: 'Secret Access Key',
|
||||
secretConfigured: '(Configured, leave blank to keep)',
|
||||
cdnUrl: 'CDN URL',
|
||||
cdnUrlHint: 'Optional. When configured, files are accessed via CDN URL instead of presigned URLs',
|
||||
cdnUrlHint: 'Optional. When configured, files are accessed via CDN URL',
|
||||
forcePathStyle: 'Force Path Style',
|
||||
defaultQuota: 'Default Storage Quota',
|
||||
defaultQuotaHint: 'Default quota when not specified at user or group level. 0 means unlimited',
|
||||
testConnection: 'Test Connection',
|
||||
testing: 'Testing...',
|
||||
testSuccess: 'S3 connection test successful',
|
||||
testFailed: 'S3 connection test failed',
|
||||
saved: 'Sora S3 settings saved successfully',
|
||||
saveFailed: 'Failed to save Sora S3 settings'
|
||||
testSuccess: 'Connection test successful',
|
||||
testFailed: 'Connection test failed',
|
||||
saved: 'Storage settings saved successfully',
|
||||
saveFailed: 'Failed to save storage settings',
|
||||
gdrive: {
|
||||
authType: 'Authentication Method',
|
||||
serviceAccount: 'Service Account',
|
||||
clientId: 'Client ID',
|
||||
clientSecret: 'Client Secret',
|
||||
clientSecretConfigured: '(Configured, leave blank to keep)',
|
||||
refreshToken: 'Refresh Token',
|
||||
refreshTokenConfigured: '(Configured, leave blank to keep)',
|
||||
serviceAccountJson: 'Service Account JSON',
|
||||
serviceAccountConfigured: '(Configured, leave blank to keep)',
|
||||
folderId: 'Folder ID (optional)',
|
||||
authorize: 'Authorize Google Drive',
|
||||
authorizeHint: 'Get Refresh Token via OAuth2',
|
||||
oauthFieldsRequired: 'Please fill in Client ID and Client Secret first',
|
||||
oauthSuccess: 'Google Drive authorization successful',
|
||||
oauthFailed: 'Google Drive authorization failed',
|
||||
closeWindow: 'This window will close automatically',
|
||||
processing: 'Processing authorization...',
|
||||
testStorage: 'Test Storage',
|
||||
testSuccess: 'Google Drive storage test passed (upload, access, delete all OK)',
|
||||
testFailed: 'Google Drive storage test failed'
|
||||
}
|
||||
},
|
||||
streamTimeout: {
|
||||
title: 'Stream Timeout Handling',
|
||||
@@ -4695,6 +4778,7 @@ export default {
|
||||
downloadLocal: 'Download',
|
||||
canDownload: 'to download',
|
||||
regenrate: 'Regenerate',
|
||||
regenerate: 'Regenerate',
|
||||
creatorPlaceholder: 'Describe the video or image you want to create...',
|
||||
videoModels: 'Video Models',
|
||||
imageModels: 'Image Models',
|
||||
@@ -4711,6 +4795,13 @@ export default {
|
||||
galleryEmptyTitle: 'No works yet',
|
||||
galleryEmptyDesc: 'Your creations will be displayed here. Go to the generate page to start your first creation.',
|
||||
startCreating: 'Start Creating',
|
||||
yesterday: 'Yesterday'
|
||||
yesterday: 'Yesterday',
|
||||
landscape: 'Landscape',
|
||||
portrait: 'Portrait',
|
||||
square: 'Square',
|
||||
examplePrompt1: 'A golden Shiba Inu walking through the streets of Shibuya, Tokyo, camera following, cinematic shot, 4K',
|
||||
examplePrompt2: 'Drone aerial view, green aurora reflecting on a glacial lake in Iceland, slow push-in',
|
||||
examplePrompt3: 'Cyberpunk futuristic city, neon lights reflected in rain puddles, nightscape, cinematic colors',
|
||||
examplePrompt4: 'Chinese ink painting style, a small boat drifting among misty mountains and rivers, classical atmosphere'
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,6 +245,7 @@ export default {
|
||||
// Common
|
||||
common: {
|
||||
loading: '加载中...',
|
||||
justNow: '刚刚',
|
||||
save: '保存',
|
||||
cancel: '取消',
|
||||
delete: '删除',
|
||||
@@ -723,6 +724,13 @@ export default {
|
||||
preparingExport: '正在准备导出...',
|
||||
model: '模型',
|
||||
reasoningEffort: '推理强度',
|
||||
endpoint: '端点',
|
||||
endpointDistribution: '端点分布',
|
||||
inbound: '入站',
|
||||
upstream: '上游',
|
||||
path: '路径',
|
||||
inboundEndpoint: '入站端点',
|
||||
upstreamEndpoint: '上游端点',
|
||||
type: '类型',
|
||||
tokens: 'Token',
|
||||
cost: '费用',
|
||||
@@ -1967,7 +1975,7 @@ export default {
|
||||
resetQuota: '重置配额',
|
||||
quotaLimit: '配额限制',
|
||||
quotaLimitPlaceholder: '0 表示不限制',
|
||||
quotaLimitHint: '设置日/周/总使用额度(美元),任一维度达到限额后账号暂停调度。修改限额不会重置已用额度。',
|
||||
quotaLimitHint: '设置日/周/总使用额度(美元),任一维度达到限额后账号暂停调度。Anthropic API Key 账号还可配置客户端亲和。修改限额不会重置已用额度。',
|
||||
quotaLimitToggle: '启用配额限制',
|
||||
quotaLimitToggleHint: '开启后,当账号用量达到设定额度时自动暂停调度',
|
||||
quotaDailyLimit: '日限额',
|
||||
@@ -2045,6 +2053,9 @@ export default {
|
||||
rateLimitedUntil: '限流中,当前不参与调度,预计 {time} 自动恢复',
|
||||
rateLimitedAutoResume: '{time} 自动恢复',
|
||||
modelRateLimitedUntil: '{model} 限流至 {time}',
|
||||
modelCreditOveragesUntil: '{model} 正在使用 AI Credits,至 {time}',
|
||||
creditsExhausted: '积分已用尽',
|
||||
creditsExhaustedUntil: 'AI Credits 已用尽,预计 {time} 恢复',
|
||||
overloadedUntil: '负载过重,重置时间:{time}',
|
||||
viewTempUnschedDetails: '查看临时不可调度详情'
|
||||
},
|
||||
@@ -2098,7 +2109,7 @@ export default {
|
||||
geminiFlashDaily: 'Flash',
|
||||
gemini3Pro: 'G3P',
|
||||
gemini3Flash: 'G3F',
|
||||
gemini3Image: 'GImage',
|
||||
gemini3Image: 'G31FI',
|
||||
claude: 'Claude'
|
||||
},
|
||||
tier: {
|
||||
@@ -2308,7 +2319,7 @@ export default {
|
||||
// Quota control (Anthropic OAuth/SetupToken only)
|
||||
quotaControl: {
|
||||
title: '配额控制',
|
||||
hint: '仅适用于 Anthropic OAuth/Setup Token 账号',
|
||||
hint: '配置费用窗口、会话限制、客户端亲和等调度控制。',
|
||||
windowCost: {
|
||||
label: '5h窗口费用控制',
|
||||
hint: '限制账号在5小时窗口内的费用使用',
|
||||
@@ -2363,8 +2374,26 @@ export default {
|
||||
hint: '将所有缓存创建 token 强制按指定的 TTL 类型(5分钟或1小时)计费',
|
||||
target: '目标 TTL',
|
||||
targetHint: '选择计费使用的 TTL 类型'
|
||||
},
|
||||
clientAffinity: {
|
||||
label: '客户端亲和调度',
|
||||
hint: '启用后,新会话会优先调度到该客户端之前使用过的账号,避免频繁切换账号'
|
||||
}
|
||||
},
|
||||
affinityNoClients: '无亲和客户端',
|
||||
affinityClients: '{count} 个亲和客户端:',
|
||||
affinitySection: '客户端亲和',
|
||||
affinitySectionHint: '控制客户端在账号间的分布。通过配置区域阈值来平衡负载。',
|
||||
affinityToggle: '启用客户端亲和',
|
||||
affinityToggleHint: '新会话优先调度到该客户端之前使用过的账号',
|
||||
affinityBase: '基础限额(绿区)',
|
||||
affinityBasePlaceholder: '留空表示不限制',
|
||||
affinityBaseHint: '绿区最大客户端数量(完整优先级调度)',
|
||||
affinityBaseOffHint: '未开启绿区限制,所有客户端均享受完整优先级调度',
|
||||
affinityBuffer: '缓冲区(黄区)',
|
||||
affinityBufferPlaceholder: '例如 3',
|
||||
affinityBufferHint: '黄区允许的额外客户端数量(降级优先级调度)',
|
||||
affinityBufferInfinite: '不限制',
|
||||
expired: '已过期',
|
||||
proxy: '代理',
|
||||
noProxy: '无代理',
|
||||
@@ -2382,6 +2411,10 @@ export default {
|
||||
mixedSchedulingHint: '启用后可参与 Anthropic/Gemini 分组的调度',
|
||||
mixedSchedulingTooltip:
|
||||
'!!注意!! Antigravity Claude 和 Anthropic Claude 无法在同个上下文中使用,如果你同时有 Anthropic 账号和 Antigravity 账号,开启此选项会导致经常 400 报错。开启后,请用分组功能做好 Antigravity 账号和 Anthropic 账号的隔离。一定要弄明白再开启!!',
|
||||
aiCreditsBalance: 'AI Credits',
|
||||
allowOverages: '允许超量请求 (AI Credits)',
|
||||
allowOveragesTooltip:
|
||||
'仅在免费配额被明确判定为耗尽后才会使用 AI Credits。普通并发 429 限流不会切换到超量请求。',
|
||||
creating: '创建中...',
|
||||
updating: '更新中...',
|
||||
accountCreated: '账号创建成功',
|
||||
@@ -4009,6 +4042,8 @@ export default {
|
||||
ignoreNoAvailableAccountsHint: '启用后,"No available accounts" 错误将不会写入错误日志(不推荐,这通常是配置问题)。',
|
||||
ignoreInvalidApiKeyErrors: '忽略无效 API Key 错误',
|
||||
ignoreInvalidApiKeyErrorsHint: '启用后,无效或缺失 API Key 的错误(INVALID_API_KEY、API_KEY_REQUIRED)将不会写入错误日志。',
|
||||
ignoreInsufficientBalanceErrors: '忽略余额不足错误',
|
||||
ignoreInsufficientBalanceErrorsHint: '启用后,账号余额不足(Insufficient balance)的错误将不会写入错误日志。',
|
||||
autoRefresh: '自动刷新',
|
||||
enableAutoRefresh: '启用自动刷新',
|
||||
enableAutoRefreshHint: '自动刷新仪表板数据,启用后会定期拉取最新数据。',
|
||||
@@ -4133,6 +4168,9 @@ export default {
|
||||
invitationCodeHint: '开启后,用户注册时需要填写有效的邀请码',
|
||||
passwordReset: '忘记密码',
|
||||
passwordResetHint: '允许用户通过邮箱重置密码',
|
||||
frontendUrl: '前端地址',
|
||||
frontendUrlPlaceholder: 'https://example.com',
|
||||
frontendUrlHint: '用于生成邮件中的密码重置链接,例如 https://example.com',
|
||||
totp: '双因素认证 (2FA)',
|
||||
totpHint: '允许用户使用 Google Authenticator 等应用进行二次验证',
|
||||
totpKeyNotConfigured:
|
||||
@@ -4346,40 +4384,55 @@ export default {
|
||||
usage: '使用方法:在请求头中添加 x-api-key: <your-admin-api-key>'
|
||||
},
|
||||
soraS3: {
|
||||
title: 'Sora S3 存储配置',
|
||||
description: '以多配置列表方式管理 Sora S3 端点,并可切换生效配置',
|
||||
title: 'Sora 存储配置',
|
||||
description: '以多配置列表管理 Sora 媒体存储,支持 S3 和 Google Drive',
|
||||
newProfile: '新建配置',
|
||||
reloadProfiles: '刷新列表',
|
||||
empty: '暂无 Sora S3 配置,请先创建',
|
||||
createTitle: '新建 Sora S3 配置',
|
||||
editTitle: '编辑 Sora S3 配置',
|
||||
empty: '暂无存储配置,请先创建',
|
||||
createTitle: '新建存储配置',
|
||||
editTitle: '编辑存储配置',
|
||||
selectProvider: '选择存储类型',
|
||||
providerS3Desc: 'S3 兼容对象存储',
|
||||
providerGDriveDesc: 'Google Drive 云盘',
|
||||
profileID: '配置 ID',
|
||||
profileName: '配置名称',
|
||||
setActive: '创建后设为生效',
|
||||
saveProfile: '保存配置',
|
||||
activateProfile: '设为生效',
|
||||
profileCreated: 'Sora S3 配置创建成功',
|
||||
profileSaved: 'Sora S3 配置保存成功',
|
||||
profileDeleted: 'Sora S3 配置删除成功',
|
||||
profileActivated: 'Sora S3 生效配置已切换',
|
||||
profileCreated: '存储配置创建成功',
|
||||
profileSaved: '存储配置保存成功',
|
||||
profileDeleted: '存储配置删除成功',
|
||||
profileActivated: '生效配置已切换',
|
||||
profileIDRequired: '请填写配置 ID',
|
||||
profileNameRequired: '请填写配置名称',
|
||||
profileSelectRequired: '请先选择配置',
|
||||
endpointRequired: '启用时必须填写 S3 端点',
|
||||
bucketRequired: '启用时必须填写存储桶',
|
||||
accessKeyRequired: '启用时必须填写 Access Key ID',
|
||||
deleteConfirm: '确定删除 Sora S3 配置 {profileID} 吗?',
|
||||
deleteConfirm: '确定删除存储配置 {profileID} 吗?',
|
||||
columns: {
|
||||
profile: '配置',
|
||||
profileId: 'Profile ID',
|
||||
name: '名称',
|
||||
provider: '存储类型',
|
||||
active: '生效状态',
|
||||
endpoint: '端点',
|
||||
bucket: '存储桶',
|
||||
storagePath: '存储路径',
|
||||
capacityUsage: '容量 / 已用',
|
||||
capacityUnlimited: '无限制',
|
||||
videoCount: '视频数',
|
||||
videoCompleted: '完成',
|
||||
videoInProgress: '进行中',
|
||||
quota: '默认配额',
|
||||
updatedAt: '更新时间',
|
||||
actions: '操作'
|
||||
actions: '操作',
|
||||
rootFolder: '根目录',
|
||||
testInTable: '测试',
|
||||
testingInTable: '测试中...',
|
||||
testTimeout: '测试超时(15秒)'
|
||||
},
|
||||
enabled: '启用 S3 存储',
|
||||
enabledHint: '启用后,Sora 生成的媒体文件将自动上传到 S3 存储',
|
||||
enabled: '启用存储',
|
||||
enabledHint: '启用后,Sora 生成的媒体文件将自动上传到存储',
|
||||
endpoint: 'S3 端点',
|
||||
region: '区域',
|
||||
bucket: '存储桶',
|
||||
@@ -4388,16 +4441,38 @@ export default {
|
||||
secretAccessKey: 'Secret Access Key',
|
||||
secretConfigured: '(已配置,留空保持不变)',
|
||||
cdnUrl: 'CDN URL',
|
||||
cdnUrlHint: '可选,配置后使用 CDN URL 访问文件,否则使用预签名 URL',
|
||||
cdnUrlHint: '可选,配置后使用 CDN URL 访问文件',
|
||||
forcePathStyle: '强制路径风格(Path Style)',
|
||||
defaultQuota: '默认存储配额',
|
||||
defaultQuotaHint: '未在用户或分组级别指定配额时的默认值,0 表示无限制',
|
||||
testConnection: '测试连接',
|
||||
testing: '测试中...',
|
||||
testSuccess: 'S3 连接测试成功',
|
||||
testFailed: 'S3 连接测试失败',
|
||||
saved: 'Sora S3 设置保存成功',
|
||||
saveFailed: '保存 Sora S3 设置失败'
|
||||
testSuccess: '连接测试成功',
|
||||
testFailed: '连接测试失败',
|
||||
saved: '存储设置保存成功',
|
||||
saveFailed: '保存存储设置失败',
|
||||
gdrive: {
|
||||
authType: '认证方式',
|
||||
serviceAccount: '服务账号',
|
||||
clientId: 'Client ID',
|
||||
clientSecret: 'Client Secret',
|
||||
clientSecretConfigured: '(已配置,留空保持不变)',
|
||||
refreshToken: 'Refresh Token',
|
||||
refreshTokenConfigured: '(已配置,留空保持不变)',
|
||||
serviceAccountJson: '服务账号 JSON',
|
||||
serviceAccountConfigured: '(已配置,留空保持不变)',
|
||||
folderId: 'Folder ID(可选)',
|
||||
authorize: '授权 Google Drive',
|
||||
authorizeHint: '通过 OAuth2 获取 Refresh Token',
|
||||
oauthFieldsRequired: '请先填写 Client ID 和 Client Secret',
|
||||
oauthSuccess: 'Google Drive 授权成功',
|
||||
oauthFailed: 'Google Drive 授权失败',
|
||||
closeWindow: '此窗口将自动关闭',
|
||||
processing: '正在处理授权...',
|
||||
testStorage: '测试存储',
|
||||
testSuccess: 'Google Drive 存储测试成功(上传、访问、删除均正常)',
|
||||
testFailed: 'Google Drive 存储测试失败'
|
||||
}
|
||||
},
|
||||
streamTimeout: {
|
||||
title: '流超时处理',
|
||||
@@ -4893,6 +4968,7 @@ export default {
|
||||
downloadLocal: '本地下载',
|
||||
canDownload: '可下载',
|
||||
regenrate: '重新生成',
|
||||
regenerate: '重新生成',
|
||||
creatorPlaceholder: '描述你想要生成的视频或图片...',
|
||||
videoModels: '视频模型',
|
||||
imageModels: '图片模型',
|
||||
@@ -4909,6 +4985,13 @@ export default {
|
||||
galleryEmptyTitle: '还没有任何作品',
|
||||
galleryEmptyDesc: '你的创作成果将会展示在这里。前往生成页,开始你的第一次创作吧。',
|
||||
startCreating: '开始创作',
|
||||
yesterday: '昨天'
|
||||
yesterday: '昨天',
|
||||
landscape: '横屏',
|
||||
portrait: '竖屏',
|
||||
square: '方形',
|
||||
examplePrompt1: '一只金色的柴犬在东京涩谷街头散步,镜头跟随,电影感画面,4K 高清',
|
||||
examplePrompt2: '无人机航拍视角,冰岛极光下的冰川湖面反射绿色光芒,慢速推进',
|
||||
examplePrompt3: '赛博朋克风格的未来城市,霓虹灯倒映在雨后积水中,夜景,电影级色彩',
|
||||
examplePrompt4: '水墨画风格,一叶扁舟在山水间漂泊,薄雾缭绕,中国古典意境'
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,6 +403,8 @@ export interface AdminGroup extends Group {
|
||||
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
mcp_xml_inject: boolean
|
||||
// Claude usage 模拟开关(仅 anthropic 平台使用)
|
||||
simulate_claude_max_enabled: boolean
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
supported_model_scopes?: string[]
|
||||
@@ -497,6 +499,7 @@ export interface CreateGroupRequest {
|
||||
fallback_group_id?: number | null
|
||||
fallback_group_id_on_invalid_request?: number | null
|
||||
mcp_xml_inject?: boolean
|
||||
simulate_claude_max_enabled?: boolean
|
||||
supported_model_scopes?: string[]
|
||||
// 从指定分组复制账号
|
||||
copy_accounts_from_group_ids?: number[]
|
||||
@@ -525,6 +528,7 @@ export interface UpdateGroupRequest {
|
||||
fallback_group_id?: number | null
|
||||
fallback_group_id_on_invalid_request?: number | null
|
||||
mcp_xml_inject?: boolean
|
||||
simulate_claude_max_enabled?: boolean
|
||||
supported_model_scopes?: string[]
|
||||
copy_accounts_from_group_ids?: number[]
|
||||
}
|
||||
@@ -720,6 +724,12 @@ export interface Account {
|
||||
cache_ttl_override_enabled?: boolean | null
|
||||
cache_ttl_override_target?: string | null
|
||||
|
||||
// 客户端亲和调度(仅 Anthropic/Antigravity 平台有效)
|
||||
// 启用后新会话会优先调度到客户端之前使用过的账号
|
||||
client_affinity_enabled?: boolean | null
|
||||
affinity_client_count?: number | null
|
||||
affinity_clients?: string[] | null
|
||||
|
||||
// API Key 账号配额限制
|
||||
quota_limit?: number | null
|
||||
quota_used?: number | null
|
||||
@@ -780,6 +790,11 @@ export interface AccountUsageInfo {
|
||||
gemini_pro_minute?: UsageProgress | null
|
||||
gemini_flash_minute?: UsageProgress | null
|
||||
antigravity_quota?: Record<string, AntigravityModelQuota> | null
|
||||
ai_credits?: Array<{
|
||||
credit_type?: string
|
||||
amount?: number
|
||||
minimum_balance?: number
|
||||
}> | null
|
||||
// Antigravity 403 forbidden 状态
|
||||
is_forbidden?: boolean
|
||||
forbidden_reason?: string
|
||||
@@ -962,6 +977,8 @@ export interface UsageLog {
|
||||
model: string
|
||||
service_tier?: string | null
|
||||
reasoning_effort?: string | null
|
||||
inbound_endpoint?: string | null
|
||||
upstream_endpoint?: string | null
|
||||
|
||||
group_id: number | null
|
||||
subscription_id: number | null
|
||||
@@ -1168,6 +1185,14 @@ export interface ModelStat {
|
||||
actual_cost: number // 实际扣除
|
||||
}
|
||||
|
||||
export interface EndpointStat {
|
||||
endpoint: string
|
||||
requests: number
|
||||
total_tokens: number
|
||||
cost: number
|
||||
actual_cost: number
|
||||
}
|
||||
|
||||
export interface GroupStat {
|
||||
group_id: number
|
||||
group_name: string
|
||||
@@ -1199,6 +1224,8 @@ export interface UserSpendingRankingItem {
|
||||
export interface UserSpendingRankingResponse {
|
||||
ranking: UserSpendingRankingItem[]
|
||||
total_actual_cost: number
|
||||
total_requests: number
|
||||
total_tokens: number
|
||||
start_date: string
|
||||
end_date: string
|
||||
}
|
||||
@@ -1362,6 +1389,8 @@ export interface AccountUsageStatsResponse {
|
||||
history: AccountUsageHistory[]
|
||||
summary: AccountUsageSummary
|
||||
models: ModelStat[]
|
||||
endpoints: EndpointStat[]
|
||||
upstream_endpoints: EndpointStat[]
|
||||
}
|
||||
|
||||
// ==================== User Attribute Types ====================
|
||||
|
||||
@@ -241,6 +241,8 @@
|
||||
:enable-ranking-view="true"
|
||||
:ranking-items="rankingItems"
|
||||
:ranking-total-actual-cost="rankingTotalActualCost"
|
||||
:ranking-total-requests="rankingTotalRequests"
|
||||
:ranking-total-tokens="rankingTotalTokens"
|
||||
:loading="chartsLoading"
|
||||
:ranking-loading="rankingLoading"
|
||||
:ranking-error="rankingError"
|
||||
@@ -334,6 +336,8 @@ const modelStats = ref<ModelStat[]>([])
|
||||
const userTrend = ref<UserUsageTrendPoint[]>([])
|
||||
const rankingItems = ref<UserSpendingRankingItem[]>([])
|
||||
const rankingTotalActualCost = ref(0)
|
||||
const rankingTotalRequests = ref(0)
|
||||
const rankingTotalTokens = ref(0)
|
||||
let chartLoadSeq = 0
|
||||
let usersTrendLoadSeq = 0
|
||||
let rankingLoadSeq = 0
|
||||
@@ -347,7 +351,7 @@ const formatLocalDate = (date: Date): string => {
|
||||
const getTodayLocalDate = () => formatLocalDate(new Date())
|
||||
|
||||
// Date range
|
||||
const granularity = ref<'day' | 'hour'>('day')
|
||||
const granularity = ref<'day' | 'hour'>('hour')
|
||||
const startDate = ref(getTodayLocalDate())
|
||||
const endDate = ref(getTodayLocalDate())
|
||||
|
||||
@@ -630,11 +634,15 @@ const loadUserSpendingRanking = async () => {
|
||||
if (currentSeq !== rankingLoadSeq) return
|
||||
rankingItems.value = response.ranking || []
|
||||
rankingTotalActualCost.value = response.total_actual_cost || 0
|
||||
rankingTotalRequests.value = response.total_requests || 0
|
||||
rankingTotalTokens.value = response.total_tokens || 0
|
||||
} catch (error) {
|
||||
if (currentSeq !== rankingLoadSeq) return
|
||||
console.error('Error loading user spending ranking:', error)
|
||||
rankingItems.value = []
|
||||
rankingTotalActualCost.value = 0
|
||||
rankingTotalRequests.value = 0
|
||||
rankingTotalTokens.value = 0
|
||||
rankingError.value = true
|
||||
} finally {
|
||||
if (currentSeq === rankingLoadSeq) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user