mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 07:22:13 +08:00
Compare commits
87 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cadca752c4 | ||
|
|
edf215e6fd | ||
|
|
e12dd079fd | ||
|
|
04a509d45e | ||
|
|
269a659200 | ||
|
|
2c31bf46b5 | ||
|
|
8f6639f825 | ||
|
|
fc17d9d7df | ||
|
|
ab092e88a8 | ||
|
|
56a1e29cdd | ||
|
|
0059a232a6 | ||
|
|
45676fdc8d | ||
|
|
e32c5f534f | ||
|
|
426d691c95 | ||
|
|
e9a4c8ab97 | ||
|
|
a55cfebd09 | ||
|
|
34cc02f8c7 | ||
|
|
624d9fddb7 | ||
|
|
47fbe43324 | ||
|
|
1245f07a2d | ||
|
|
839975b0cf | ||
|
|
8c1233393f | ||
|
|
9cdb0568cc | ||
|
|
74e05b83ea | ||
|
|
4ded9e7d49 | ||
|
|
716272a1e2 | ||
|
|
9cc8352593 | ||
|
|
43a1031e38 | ||
|
|
a5547b2f30 | ||
|
|
b0aa23540b | ||
|
|
ffaa6c4a17 | ||
|
|
fbf72f0ec4 | ||
|
|
909b8a8f9c | ||
|
|
4a0fe3b143 | ||
|
|
a1292fac81 | ||
|
|
7f98be4f91 | ||
|
|
fd73b8875d | ||
|
|
f9ab1daa3c | ||
|
|
d27b847442 | ||
|
|
dac6bc2228 | ||
|
|
4bd3dbf2ce | ||
|
|
226df1c23a | ||
|
|
2665230a09 | ||
|
|
4f0c2b794c | ||
|
|
e756064c19 | ||
|
|
17dfb0af01 | ||
|
|
ff74f517df | ||
|
|
477a9a180f | ||
|
|
da48df06d2 | ||
|
|
39fad63ccf | ||
|
|
5602d02b1b | ||
|
|
81989eed1c | ||
|
|
192efb84a0 | ||
|
|
8672347f93 | ||
|
|
5e5d4a513b | ||
|
|
88b6358472 | ||
|
|
dd8d5e2c42 | ||
|
|
d91e2328fb | ||
|
|
2a16735495 | ||
|
|
292f25f9ca | ||
|
|
c92e37775a | ||
|
|
f6ed3d1456 | ||
|
|
84686753e8 | ||
|
|
91f01309da | ||
|
|
57a1fc9d33 | ||
|
|
c95a864975 | ||
|
|
7a83db6180 | ||
|
|
a8513da7ff | ||
|
|
53534d3956 | ||
|
|
cc07a0e295 | ||
|
|
e7bc62500b | ||
|
|
c8fb9ef3a5 | ||
|
|
eb5e6214bc | ||
|
|
568d6ee10e | ||
|
|
6aef1af76e | ||
|
|
a54852e129 | ||
|
|
668118def1 | ||
|
|
73e6b160f8 | ||
|
|
6fec141de6 | ||
|
|
31cde6c555 | ||
|
|
b1a980f344 | ||
|
|
00d9fbd220 | ||
|
|
4f4c9679bf | ||
|
|
3dab71729d | ||
|
|
2f6f758670 | ||
|
|
090c8981dd | ||
|
|
fbb572948d |
3
.github/workflows/release.yml
vendored
3
.github/workflows/release.yml
vendored
@@ -222,8 +222,9 @@ jobs:
|
||||
REPO="${{ github.repository }}"
|
||||
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
|
||||
|
||||
# 获取 tag message 内容
|
||||
# 获取 tag message 内容并转义 Markdown 特殊字符
|
||||
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
|
||||
TAG_MESSAGE=$(echo "$TAG_MESSAGE" | sed 's/\([_*`\[]\)/\\\1/g')
|
||||
|
||||
# 限制消息长度(Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
|
||||
if [ ${#TAG_MESSAGE} -gt 3500 ]; then
|
||||
|
||||
@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
|
||||
|
||||
## Demo
|
||||
|
||||
Try Sub2API online: **https://v2.pincc.ai/**
|
||||
Try Sub2API online: **https://demo.sub2api.org/**
|
||||
|
||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.1.46
|
||||
0.1.61
|
||||
|
||||
@@ -70,6 +70,7 @@ func provideCleanup(
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||
usageCleanup *service.UsageCleanupService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
@@ -138,6 +139,10 @@ func provideCleanup(
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"SubscriptionExpiryService", func() error {
|
||||
subscriptionExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -63,7 +63,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
||||
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
totpCache := repository.NewTotpCache(redisClient)
|
||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
@@ -84,7 +90,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
}
|
||||
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||
accountRepository := repository.NewAccountRepository(client, db)
|
||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
@@ -105,21 +112,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||
usageCache := service.NewUsageCache()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -128,7 +136,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
promoHandler := admin.NewPromoHandler(promoService)
|
||||
opsRepository := repository.NewOpsRepository(db)
|
||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
@@ -137,7 +144,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
@@ -165,7 +171,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
@@ -178,7 +185,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -211,6 +219,7 @@ func provideCleanup(
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||
usageCleanup *service.UsageCleanupService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
@@ -278,6 +287,10 @@ func provideCleanup(
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"SubscriptionExpiryService", func() error {
|
||||
subscriptionExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -610,6 +610,9 @@ var (
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
|
||||
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
||||
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
||||
}
|
||||
// UsersTable holds the schema information for the "users" table.
|
||||
UsersTable = &schema.Table{
|
||||
|
||||
@@ -14360,6 +14360,9 @@ type UserMutation struct {
|
||||
status *string
|
||||
username *string
|
||||
notes *string
|
||||
totp_secret_encrypted *string
|
||||
totp_enabled *bool
|
||||
totp_enabled_at *time.Time
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -14937,6 +14940,140 @@ func (m *UserMutation) ResetNotes() {
|
||||
m.notes = nil
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (m *UserMutation) SetTotpSecretEncrypted(s string) {
|
||||
m.totp_secret_encrypted = &s
|
||||
}
|
||||
|
||||
// TotpSecretEncrypted returns the value of the "totp_secret_encrypted" field in the mutation.
|
||||
func (m *UserMutation) TotpSecretEncrypted() (r string, exists bool) {
|
||||
v := m.totp_secret_encrypted
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldTotpSecretEncrypted returns the old "totp_secret_encrypted" field's value of the User entity.
|
||||
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UserMutation) OldTotpSecretEncrypted(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldTotpSecretEncrypted is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldTotpSecretEncrypted requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldTotpSecretEncrypted: %w", err)
|
||||
}
|
||||
return oldValue.TotpSecretEncrypted, nil
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (m *UserMutation) ClearTotpSecretEncrypted() {
|
||||
m.totp_secret_encrypted = nil
|
||||
m.clearedFields[user.FieldTotpSecretEncrypted] = struct{}{}
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedCleared returns if the "totp_secret_encrypted" field was cleared in this mutation.
|
||||
func (m *UserMutation) TotpSecretEncryptedCleared() bool {
|
||||
_, ok := m.clearedFields[user.FieldTotpSecretEncrypted]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetTotpSecretEncrypted resets all changes to the "totp_secret_encrypted" field.
|
||||
func (m *UserMutation) ResetTotpSecretEncrypted() {
|
||||
m.totp_secret_encrypted = nil
|
||||
delete(m.clearedFields, user.FieldTotpSecretEncrypted)
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (m *UserMutation) SetTotpEnabled(b bool) {
|
||||
m.totp_enabled = &b
|
||||
}
|
||||
|
||||
// TotpEnabled returns the value of the "totp_enabled" field in the mutation.
|
||||
func (m *UserMutation) TotpEnabled() (r bool, exists bool) {
|
||||
v := m.totp_enabled
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldTotpEnabled returns the old "totp_enabled" field's value of the User entity.
|
||||
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UserMutation) OldTotpEnabled(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldTotpEnabled is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldTotpEnabled requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldTotpEnabled: %w", err)
|
||||
}
|
||||
return oldValue.TotpEnabled, nil
|
||||
}
|
||||
|
||||
// ResetTotpEnabled resets all changes to the "totp_enabled" field.
|
||||
func (m *UserMutation) ResetTotpEnabled() {
|
||||
m.totp_enabled = nil
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (m *UserMutation) SetTotpEnabledAt(t time.Time) {
|
||||
m.totp_enabled_at = &t
|
||||
}
|
||||
|
||||
// TotpEnabledAt returns the value of the "totp_enabled_at" field in the mutation.
|
||||
func (m *UserMutation) TotpEnabledAt() (r time.Time, exists bool) {
|
||||
v := m.totp_enabled_at
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldTotpEnabledAt returns the old "totp_enabled_at" field's value of the User entity.
|
||||
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UserMutation) OldTotpEnabledAt(ctx context.Context) (v *time.Time, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldTotpEnabledAt is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldTotpEnabledAt requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldTotpEnabledAt: %w", err)
|
||||
}
|
||||
return oldValue.TotpEnabledAt, nil
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (m *UserMutation) ClearTotpEnabledAt() {
|
||||
m.totp_enabled_at = nil
|
||||
m.clearedFields[user.FieldTotpEnabledAt] = struct{}{}
|
||||
}
|
||||
|
||||
// TotpEnabledAtCleared returns if the "totp_enabled_at" field was cleared in this mutation.
|
||||
func (m *UserMutation) TotpEnabledAtCleared() bool {
|
||||
_, ok := m.clearedFields[user.FieldTotpEnabledAt]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetTotpEnabledAt resets all changes to the "totp_enabled_at" field.
|
||||
func (m *UserMutation) ResetTotpEnabledAt() {
|
||||
m.totp_enabled_at = nil
|
||||
delete(m.clearedFields, user.FieldTotpEnabledAt)
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -15403,7 +15540,7 @@ func (m *UserMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *UserMutation) Fields() []string {
|
||||
fields := make([]string, 0, 11)
|
||||
fields := make([]string, 0, 14)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, user.FieldCreatedAt)
|
||||
}
|
||||
@@ -15437,6 +15574,15 @@ func (m *UserMutation) Fields() []string {
|
||||
if m.notes != nil {
|
||||
fields = append(fields, user.FieldNotes)
|
||||
}
|
||||
if m.totp_secret_encrypted != nil {
|
||||
fields = append(fields, user.FieldTotpSecretEncrypted)
|
||||
}
|
||||
if m.totp_enabled != nil {
|
||||
fields = append(fields, user.FieldTotpEnabled)
|
||||
}
|
||||
if m.totp_enabled_at != nil {
|
||||
fields = append(fields, user.FieldTotpEnabledAt)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -15467,6 +15613,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.Username()
|
||||
case user.FieldNotes:
|
||||
return m.Notes()
|
||||
case user.FieldTotpSecretEncrypted:
|
||||
return m.TotpSecretEncrypted()
|
||||
case user.FieldTotpEnabled:
|
||||
return m.TotpEnabled()
|
||||
case user.FieldTotpEnabledAt:
|
||||
return m.TotpEnabledAt()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -15498,6 +15650,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
|
||||
return m.OldUsername(ctx)
|
||||
case user.FieldNotes:
|
||||
return m.OldNotes(ctx)
|
||||
case user.FieldTotpSecretEncrypted:
|
||||
return m.OldTotpSecretEncrypted(ctx)
|
||||
case user.FieldTotpEnabled:
|
||||
return m.OldTotpEnabled(ctx)
|
||||
case user.FieldTotpEnabledAt:
|
||||
return m.OldTotpEnabledAt(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown User field %s", name)
|
||||
}
|
||||
@@ -15584,6 +15742,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetNotes(v)
|
||||
return nil
|
||||
case user.FieldTotpSecretEncrypted:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetTotpSecretEncrypted(v)
|
||||
return nil
|
||||
case user.FieldTotpEnabled:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetTotpEnabled(v)
|
||||
return nil
|
||||
case user.FieldTotpEnabledAt:
|
||||
v, ok := value.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetTotpEnabledAt(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown User field %s", name)
|
||||
}
|
||||
@@ -15644,6 +15823,12 @@ func (m *UserMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(user.FieldDeletedAt) {
|
||||
fields = append(fields, user.FieldDeletedAt)
|
||||
}
|
||||
if m.FieldCleared(user.FieldTotpSecretEncrypted) {
|
||||
fields = append(fields, user.FieldTotpSecretEncrypted)
|
||||
}
|
||||
if m.FieldCleared(user.FieldTotpEnabledAt) {
|
||||
fields = append(fields, user.FieldTotpEnabledAt)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -15661,6 +15846,12 @@ func (m *UserMutation) ClearField(name string) error {
|
||||
case user.FieldDeletedAt:
|
||||
m.ClearDeletedAt()
|
||||
return nil
|
||||
case user.FieldTotpSecretEncrypted:
|
||||
m.ClearTotpSecretEncrypted()
|
||||
return nil
|
||||
case user.FieldTotpEnabledAt:
|
||||
m.ClearTotpEnabledAt()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown User nullable field %s", name)
|
||||
}
|
||||
@@ -15702,6 +15893,15 @@ func (m *UserMutation) ResetField(name string) error {
|
||||
case user.FieldNotes:
|
||||
m.ResetNotes()
|
||||
return nil
|
||||
case user.FieldTotpSecretEncrypted:
|
||||
m.ResetTotpSecretEncrypted()
|
||||
return nil
|
||||
case user.FieldTotpEnabled:
|
||||
m.ResetTotpEnabled()
|
||||
return nil
|
||||
case user.FieldTotpEnabledAt:
|
||||
m.ResetTotpEnabledAt()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown User field %s", name)
|
||||
}
|
||||
|
||||
@@ -736,6 +736,10 @@ func init() {
|
||||
userDescNotes := userFields[7].Descriptor()
|
||||
// user.DefaultNotes holds the default value on creation for the notes field.
|
||||
user.DefaultNotes = userDescNotes.Default.(string)
|
||||
// userDescTotpEnabled is the schema descriptor for totp_enabled field.
|
||||
userDescTotpEnabled := userFields[9].Descriptor()
|
||||
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
||||
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||
_ = userallowedgroupFields
|
||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||
|
||||
@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
|
||||
field.String("notes").
|
||||
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||
Default(""),
|
||||
|
||||
// TOTP 双因素认证字段
|
||||
field.String("totp_secret_encrypted").
|
||||
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Bool("totp_enabled").
|
||||
Default(false),
|
||||
field.Time("totp_enabled_at").
|
||||
Optional().
|
||||
Nillable(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +39,12 @@ type User struct {
|
||||
Username string `json:"username,omitempty"`
|
||||
// Notes holds the value of the "notes" field.
|
||||
Notes string `json:"notes,omitempty"`
|
||||
// TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
|
||||
TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
|
||||
// TotpEnabled holds the value of the "totp_enabled" field.
|
||||
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
||||
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
||||
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the UserQuery when eager-loading is set.
|
||||
Edges UserEdges `json:"edges"`
|
||||
@@ -156,13 +162,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case user.FieldTotpEnabled:
|
||||
values[i] = new(sql.NullBool)
|
||||
case user.FieldBalance:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case user.FieldID, user.FieldConcurrency:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
|
||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
||||
values[i] = new(sql.NullString)
|
||||
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
|
||||
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
@@ -252,6 +260,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Notes = value.String
|
||||
}
|
||||
case user.FieldTotpSecretEncrypted:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
|
||||
} else if value.Valid {
|
||||
_m.TotpSecretEncrypted = new(string)
|
||||
*_m.TotpSecretEncrypted = value.String
|
||||
}
|
||||
case user.FieldTotpEnabled:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
|
||||
} else if value.Valid {
|
||||
_m.TotpEnabled = value.Bool
|
||||
}
|
||||
case user.FieldTotpEnabledAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.TotpEnabledAt = new(time.Time)
|
||||
*_m.TotpEnabledAt = value.Time
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -367,6 +395,19 @@ func (_m *User) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("notes=")
|
||||
builder.WriteString(_m.Notes)
|
||||
builder.WriteString(", ")
|
||||
if v := _m.TotpSecretEncrypted; v != nil {
|
||||
builder.WriteString("totp_secret_encrypted=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("totp_enabled=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.TotpEnabledAt; v != nil {
|
||||
builder.WriteString("totp_enabled_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -37,6 +37,12 @@ const (
|
||||
FieldUsername = "username"
|
||||
// FieldNotes holds the string denoting the notes field in the database.
|
||||
FieldNotes = "notes"
|
||||
// FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
|
||||
FieldTotpSecretEncrypted = "totp_secret_encrypted"
|
||||
// FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
|
||||
FieldTotpEnabled = "totp_enabled"
|
||||
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
||||
FieldTotpEnabledAt = "totp_enabled_at"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -134,6 +140,9 @@ var Columns = []string{
|
||||
FieldStatus,
|
||||
FieldUsername,
|
||||
FieldNotes,
|
||||
FieldTotpSecretEncrypted,
|
||||
FieldTotpEnabled,
|
||||
FieldTotpEnabledAt,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -188,6 +197,8 @@ var (
|
||||
UsernameValidator func(string) error
|
||||
// DefaultNotes holds the default value on creation for the "notes" field.
|
||||
DefaultNotes string
|
||||
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
||||
DefaultTotpEnabled bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the User queries.
|
||||
@@ -253,6 +264,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldNotes, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
|
||||
func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByTotpEnabled orders the results by the totp_enabled field.
|
||||
func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByTotpEnabledAt orders the results by the totp_enabled_at field.
|
||||
func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldNotes, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
|
||||
func TotpSecretEncrypted(v string) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
|
||||
func TotpEnabled(v bool) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
|
||||
func TotpEnabledAt(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
|
||||
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedEQ(v string) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedNEQ(v string) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedIn(vs ...string) predicate.User {
|
||||
return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
|
||||
return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedGT(v string) predicate.User {
|
||||
return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedGTE(v string) predicate.User {
|
||||
return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedLT(v string) predicate.User {
|
||||
return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedLTE(v string) predicate.User {
|
||||
return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedContains(v string) predicate.User {
|
||||
return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedHasPrefix(v string) predicate.User {
|
||||
return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedHasSuffix(v string) predicate.User {
|
||||
return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedIsNil() predicate.User {
|
||||
return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedNotNil() predicate.User {
|
||||
return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedEqualFold(v string) predicate.User {
|
||||
return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedContainsFold(v string) predicate.User {
|
||||
return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
|
||||
func TotpEnabledEQ(v bool) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
|
||||
}
|
||||
|
||||
// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
|
||||
func TotpEnabledNEQ(v bool) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtEQ(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtNEQ(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtIn(vs ...time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
|
||||
}
|
||||
|
||||
// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
|
||||
}
|
||||
|
||||
// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtGT(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtGTE(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtLT(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtLTE(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtIsNil() predicate.User {
|
||||
return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
|
||||
}
|
||||
|
||||
// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtNotNil() predicate.User {
|
||||
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
|
||||
@@ -167,6 +167,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
|
||||
_c.mutation.SetTotpSecretEncrypted(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetTotpSecretEncrypted(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
|
||||
_c.mutation.SetTotpEnabled(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetTotpEnabled(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
|
||||
_c.mutation.SetTotpEnabledAt(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetTotpEnabledAt(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -362,6 +404,10 @@ func (_c *UserCreate) defaults() error {
|
||||
v := user.DefaultNotes
|
||||
_c.mutation.SetNotes(v)
|
||||
}
|
||||
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||
v := user.DefaultTotpEnabled
|
||||
_c.mutation.SetTotpEnabled(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -422,6 +468,9 @@ func (_c *UserCreate) check() error {
|
||||
if _, ok := _c.mutation.Notes(); !ok {
|
||||
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -493,6 +542,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||
_node.Notes = value
|
||||
}
|
||||
if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
|
||||
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||
_node.TotpSecretEncrypted = &value
|
||||
}
|
||||
if value, ok := _c.mutation.TotpEnabled(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||
_node.TotpEnabled = value
|
||||
}
|
||||
if value, ok := _c.mutation.TotpEnabledAt(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||
_node.TotpEnabledAt = &value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -815,6 +876,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
|
||||
u.Set(user.FieldTotpSecretEncrypted, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
|
||||
u.SetExcluded(user.FieldTotpSecretEncrypted)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
|
||||
u.SetNull(user.FieldTotpSecretEncrypted)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
|
||||
u.Set(user.FieldTotpEnabled, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
|
||||
u.SetExcluded(user.FieldTotpEnabled)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
|
||||
u.Set(user.FieldTotpEnabledAt, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
|
||||
u.SetExcluded(user.FieldTotpEnabledAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
|
||||
u.SetNull(user.FieldTotpEnabledAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1021,6 +1130,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpSecretEncrypted(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpSecretEncrypted()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.ClearTotpSecretEncrypted()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpEnabledAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpEnabledAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.ClearTotpEnabledAt()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -1393,6 +1558,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpSecretEncrypted(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpSecretEncrypted()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.ClearTotpSecretEncrypted()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpEnabledAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpEnabledAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.ClearTotpEnabledAt()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -187,6 +187,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
|
||||
_u.mutation.SetTotpSecretEncrypted(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetTotpSecretEncrypted(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
|
||||
_u.mutation.ClearTotpSecretEncrypted()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
|
||||
_u.mutation.SetTotpEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetTotpEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
|
||||
_u.mutation.SetTotpEnabledAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetTotpEnabledAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
|
||||
_u.mutation.ClearTotpEnabledAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -603,6 +657,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.Notes(); ok {
|
||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
|
||||
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.TotpSecretEncryptedCleared() {
|
||||
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpEnabled(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpEnabledAt(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.TotpEnabledAtCleared() {
|
||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1147,6 +1216,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
|
||||
_u.mutation.SetTotpSecretEncrypted(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetTotpSecretEncrypted(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
|
||||
_u.mutation.ClearTotpSecretEncrypted()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
|
||||
_u.mutation.SetTotpEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetTotpEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
|
||||
_u.mutation.SetTotpEnabledAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetTotpEnabledAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
|
||||
_u.mutation.ClearTotpEnabledAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1593,6 +1716,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
||||
if value, ok := _u.mutation.Notes(); ok {
|
||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
|
||||
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.TotpSecretEncryptedCleared() {
|
||||
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpEnabled(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpEnabledAt(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.TotpEnabledAtCleared() {
|
||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -37,6 +37,7 @@ require (
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -106,6 +107,7 @@ require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/pquerna/otp v1.5.0 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.57.1 // indirect
|
||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||
|
||||
@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
|
||||
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
@@ -217,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
|
||||
@@ -47,6 +47,7 @@ type Config struct {
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Ops OpsConfig `mapstructure:"ops"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Totp TotpConfig `mapstructure:"totp"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
@@ -466,6 +467,16 @@ type JWTConfig struct {
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
}
|
||||
|
||||
// TotpConfig TOTP 双因素认证配置
|
||||
type TotpConfig struct {
|
||||
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥(32 字节 hex 编码)
|
||||
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
|
||||
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
|
||||
EncryptionKeyConfigured bool `mapstructure:"-"`
|
||||
}
|
||||
|
||||
type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
@@ -626,6 +637,20 @@ func Load() (*Config, error) {
|
||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
|
||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||
if cfg.Totp.EncryptionKey == "" {
|
||||
key, err := generateJWTSecret(32) // Reuse the same random generation function
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
|
||||
}
|
||||
cfg.Totp.EncryptionKey = key
|
||||
cfg.Totp.EncryptionKeyConfigured = false
|
||||
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||
} else {
|
||||
cfg.Totp.EncryptionKeyConfigured = true
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
@@ -756,6 +781,9 @@ func setDefaults() {
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// TOTP
|
||||
viper.SetDefault("totp.encryption_key", "")
|
||||
|
||||
// Default
|
||||
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
|
||||
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
|
||||
|
||||
@@ -45,6 +45,7 @@ type AccountHandler struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
sessionLimitCache service.SessionLimitCache
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
@@ -60,6 +61,7 @@ func NewAccountHandler(
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
crsSyncService *service.CRSSyncService,
|
||||
sessionLimitCache service.SessionLimitCache,
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
@@ -73,6 +75,7 @@ func NewAccountHandler(
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
tokenCacheInvalidator: tokenCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,6 +176,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
@@ -181,6 +185,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -189,9 +194,9 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
|
||||
// 获取活跃会话数(批量查询)
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
@@ -542,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
@@ -552,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
}
|
||||
// 标记账户为 error
|
||||
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
|
||||
response.InternalError(c, "Failed to set account error: "+setErr.Error())
|
||||
return
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed but project_id is missing, account marked as error",
|
||||
"warning": "missing_project_id",
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -606,6 +616,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
@@ -655,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token(触发刷新或从 DB 读取)
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
|
||||
@@ -94,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||
}
|
||||
response.Paginated(c, outGroups, total, page, pageSize)
|
||||
}
|
||||
@@ -120,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||
}
|
||||
response.Success(c, outGroups)
|
||||
}
|
||||
@@ -142,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
@@ -177,7 +177,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Update handles updating a group
|
||||
@@ -219,7 +219,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Delete handles deleting a group
|
||||
|
||||
@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||
}
|
||||
|
||||
// Generate handles generating new redeem codes
|
||||
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||
}
|
||||
|
||||
// GetStats handles getting redeem code statistics
|
||||
|
||||
@@ -47,6 +47,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
@@ -68,6 +72,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -87,8 +94,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -111,13 +121,16 @@ type UpdateSettingsRequest struct {
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -194,6 +207,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// TOTP 双因素认证参数验证
|
||||
// 只有手动配置了加密密钥才允许启用 TOTP 功能
|
||||
if req.TotpEnabled && !previousSettings.TotpEnabled {
|
||||
// 尝试启用 TOTP,检查加密密钥是否已手动配置
|
||||
if !h.settingService.IsTotpEncryptionKeyConfigured() {
|
||||
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
@@ -223,6 +246,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// “购买订阅”页面配置验证
|
||||
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
|
||||
if req.PurchaseSubscriptionEnabled != nil {
|
||||
purchaseEnabled = *req.PurchaseSubscriptionEnabled
|
||||
}
|
||||
purchaseURL := previousSettings.PurchaseSubscriptionURL
|
||||
if req.PurchaseSubscriptionURL != nil {
|
||||
purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
|
||||
}
|
||||
|
||||
// - 启用时要求 URL 合法且非空
|
||||
// - 禁用时允许为空;若提供了 URL 也做基本校验,避免误配置
|
||||
if purchaseEnabled {
|
||||
if purchaseURL == "" {
|
||||
response.BadRequest(c, "Purchase Subscription URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
|
||||
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
} else if purchaseURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
|
||||
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -236,38 +287,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -311,6 +368,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
@@ -332,6 +393,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
HomeContent: updatedSettings.HomeContent,
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -376,6 +440,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
if before.SMTPHost != after.SMTPHost {
|
||||
changed = append(changed, "smtp_host")
|
||||
}
|
||||
@@ -439,6 +509,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.HomeContent != after.HomeContent {
|
||||
changed = append(changed, "home_content")
|
||||
}
|
||||
if before.HideCcsImportButton != after.HideCcsImportButton {
|
||||
changed = append(changed, "hide_ccs_import_button")
|
||||
}
|
||||
if before.DefaultConcurrency != after.DefaultConcurrency {
|
||||
changed = append(changed, "default_concurrency")
|
||||
}
|
||||
|
||||
@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// ExtendSubscriptionRequest represents extend subscription request
|
||||
type ExtendSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
|
||||
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
|
||||
type AdjustSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
|
||||
}
|
||||
|
||||
// List handles listing all subscriptions with pagination and filters
|
||||
@@ -77,15 +77,19 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
}
|
||||
status := c.Query("status")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
// Parse sorting parameters
|
||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
@@ -105,7 +109,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription usage progress
|
||||
@@ -150,7 +154,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||
@@ -180,7 +184,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
response.Success(c, dto.BulkAssignResultFromService(result))
|
||||
}
|
||||
|
||||
// Extend handles extending a subscription
|
||||
// Extend handles adjusting a subscription (extend or shorten)
|
||||
// POST /api/v1/admin/subscriptions/:id/extend
|
||||
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
@@ -189,7 +193,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var req ExtendSubscriptionRequest
|
||||
var req AdjustSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
@@ -201,7 +205,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
@@ -239,9 +243,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
@@ -261,9 +265,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
@@ -163,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
out := make([]dto.AdminUsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.User, 0, len(users))
|
||||
out := make([]dto.AdminUser, 0, len(users))
|
||||
for i := range users {
|
||||
out = append(out, *dto.UserFromService(&users[i]))
|
||||
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Create handles creating a new user
|
||||
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Update handles updating a user
|
||||
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Delete handles deleting a user
|
||||
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
@@ -18,16 +20,18 @@ type AuthHandler struct {
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if TOTP 2FA is enabled for this user
|
||||
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||
// Create a temporary login session for 2FA
|
||||
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpLoginResponse{
|
||||
Requires2FA: true,
|
||||
TempToken: tempToken,
|
||||
UserEmailMasked: service.MaskEmail(user.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// TotpLoginResponse represents the response when 2FA is required
|
||||
type TotpLoginResponse struct {
|
||||
Requires2FA bool `json:"requires_2fa"`
|
||||
TempToken string `json:"temp_token,omitempty"`
|
||||
UserEmailMasked string `json:"user_email_masked,omitempty"`
|
||||
}
|
||||
|
||||
// Login2FARequest represents the 2FA login request
|
||||
type Login2FARequest struct {
|
||||
TempToken string `json:"temp_token" binding:"required"`
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
}
|
||||
|
||||
// Login2FA completes the login with 2FA verification
|
||||
// POST /api/v1/auth/login/2fa
|
||||
func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
var req Login2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_request",
|
||||
"temp_token_len", len(req.TempToken),
|
||||
"totp_code_len", len(req.TotpCode))
|
||||
|
||||
// Get the login session
|
||||
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
|
||||
if err != nil || session == nil {
|
||||
tokenPrefix := ""
|
||||
if len(req.TempToken) >= 8 {
|
||||
tokenPrefix = req.TempToken[:8]
|
||||
}
|
||||
slog.Debug("login_2fa_session_invalid",
|
||||
"temp_token_prefix", tokenPrefix,
|
||||
"error", err)
|
||||
response.BadRequest(c, "Invalid or expired 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_session_found",
|
||||
"user_id", session.UserID,
|
||||
"email", session.Email)
|
||||
|
||||
// Verify the TOTP code
|
||||
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
|
||||
slog.Debug("login_2fa_verify_failed",
|
||||
"user_id", session.UserID,
|
||||
"error", err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate the JWT token
|
||||
token, err := h.authService.GenerateToken(user)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
@@ -195,6 +293,15 @@ type ValidatePromoCodeResponse struct {
|
||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-promo-code
|
||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
// 检查优惠码功能是否启用
|
||||
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "PROMO_CODE_DISABLED",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req ValidatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
@@ -238,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
// ForgotPasswordRequest 忘记密码请求
|
||||
type ForgotPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// ForgotPasswordResponse 忘记密码响应
|
||||
type ForgotPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ForgotPassword 请求密码重置
|
||||
// POST /api/v1/auth/forgot-password
|
||||
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
var req ForgotPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build frontend base URL from request
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ForgotPasswordResponse{
|
||||
Message: "If your email is registered, you will receive a password reset link shortly.",
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Token string `json:"token" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// ResetPasswordResponse 重置密码响应
|
||||
type ResetPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// POST /api/v1/auth/reset-password
|
||||
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
var req ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Reset password
|
||||
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ResetPasswordResponse{
|
||||
Message: "Your password has been reset successfully. You can now log in with your new password.",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
|
||||
return out
|
||||
}
|
||||
|
||||
// UserFromServiceAdmin converts a service User to DTO for admin users.
|
||||
// It includes notes - user-facing endpoints must not use this.
|
||||
func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
base := UserFromService(u)
|
||||
if base == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
@@ -72,36 +87,29 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
out := groupFromServiceBase(g)
|
||||
return &out
|
||||
}
|
||||
|
||||
func GroupFromService(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := GroupFromServiceShallow(g)
|
||||
return GroupFromServiceShallow(g)
|
||||
}
|
||||
|
||||
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
|
||||
// It includes internal fields like model_routing and account_count.
|
||||
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
for i := range g.AccountGroups {
|
||||
@@ -112,6 +120,29 @@ func GroupFromService(g *service.Group) *Group {
|
||||
return out
|
||||
}
|
||||
|
||||
func groupFromServiceBase(g *service.Group) Group {
|
||||
return Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
@@ -273,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &RedeemCode{
|
||||
out := redeemCodeFromServiceBase(rc)
|
||||
return &out
|
||||
}
|
||||
|
||||
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
|
||||
// It includes notes - user-facing endpoints must not use this.
|
||||
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminRedeemCode{
|
||||
RedeemCode: redeemCodeFromServiceBase(rc),
|
||||
Notes: rc.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||
return RedeemCode{
|
||||
ID: rc.ID,
|
||||
Code: rc.Code,
|
||||
Type: rc.Type,
|
||||
@@ -281,7 +329,6 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
Status: rc.Status,
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
Notes: rc.Notes,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
@@ -302,14 +349,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
}
|
||||
}
|
||||
|
||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
||||
// The account parameter allows caller to control what Account info is included.
|
||||
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
result := &UsageLog{
|
||||
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
@@ -331,7 +373,6 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
||||
TotalCost: l.TotalCost,
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
Stream: l.Stream,
|
||||
DurationMs: l.DurationMs,
|
||||
@@ -342,30 +383,33 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
Account: account,
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
// IP 地址仅对管理员可见
|
||||
if includeIPAddress {
|
||||
result.IPAddress = l.IPAddress
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||
// It excludes Account details and IP address - users should not see these.
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return usageLogFromServiceBase(l, nil, false)
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
u := usageLogFromServiceUser(l)
|
||||
return &u
|
||||
}
|
||||
|
||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||
// It includes minimal Account info (ID, Name only) and IP address.
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
|
||||
return &AdminUsageLog{
|
||||
UsageLog: usageLogFromServiceUser(l),
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
IPAddress: l.IPAddress,
|
||||
Account: AccountSummaryFromService(l.Account),
|
||||
}
|
||||
}
|
||||
|
||||
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
|
||||
@@ -414,7 +458,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserSubscription{
|
||||
out := userSubscriptionFromServiceBase(sub)
|
||||
return &out
|
||||
}
|
||||
|
||||
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
|
||||
// It includes assignment metadata and notes.
|
||||
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminUserSubscription{
|
||||
UserSubscription: userSubscriptionFromServiceBase(sub),
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
|
||||
return UserSubscription{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
GroupID: sub.GroupID,
|
||||
@@ -427,14 +491,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
||||
DailyUsageUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
CreatedAt: sub.CreatedAt,
|
||||
UpdatedAt: sub.UpdatedAt,
|
||||
User: UserFromServiceShallow(sub.User),
|
||||
Group: GroupFromServiceShallow(sub.Group),
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,9 +502,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
subs := make([]UserSubscription, 0, len(r.Subscriptions))
|
||||
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
|
||||
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
|
||||
@@ -2,8 +2,12 @@ package dto
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -22,13 +26,16 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -52,19 +59,25 @@ type SystemSettings struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
|
||||
@@ -6,7 +6,6 @@ type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
@@ -19,6 +18,14 @@ type User struct {
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUser 是管理员接口使用的 user DTO(包含敏感/内部字段)。
|
||||
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
|
||||
type AdminUser struct {
|
||||
User
|
||||
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -58,13 +65,19 @@ type Group struct {
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AdminGroup 是管理员接口使用的 group DTO(包含敏感/内部字段)。
|
||||
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
|
||||
type AdminGroup struct {
|
||||
Group
|
||||
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
}
|
||||
@@ -180,7 +193,6 @@ type RedeemCode struct {
|
||||
Status string `json:"status"`
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
@@ -190,6 +202,15 @@ type RedeemCode struct {
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// AdminRedeemCode 是管理员接口使用的 redeem code DTO(包含 notes 等字段)。
|
||||
// 注意:普通用户接口不得返回 notes 等内部信息。
|
||||
type AdminRedeemCode struct {
|
||||
RedeemCode
|
||||
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// UsageLog 是普通用户接口使用的 usage log DTO(不包含管理员字段)。
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -209,14 +230,13 @@ type UsageLog struct {
|
||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType int8 `json:"billing_type"`
|
||||
Stream bool `json:"stream"`
|
||||
@@ -230,18 +250,28 @@ type UsageLog struct {
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// IP 地址(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
APIKey *APIKey `json:"api_key,omitempty"`
|
||||
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUsageLog 是管理员接口使用的 usage log DTO(包含管理员字段)。
|
||||
type AdminUsageLog struct {
|
||||
UsageLog
|
||||
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
|
||||
// IPAddress 用户请求 IP(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
// Account 最小账号信息(避免泄露敏感字段)
|
||||
Account *AccountSummary `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
type UsageCleanupFilters struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
@@ -300,23 +330,30 @@ type UserSubscription struct {
|
||||
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUserSubscription 是管理员接口使用的订阅 DTO(包含分配信息/备注等字段)。
|
||||
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
|
||||
type AdminUserSubscription struct {
|
||||
UserSubscription
|
||||
|
||||
AssignedBy *int64 `json:"assigned_by"`
|
||||
AssignedAt time.Time `json:"assigned_at"`
|
||||
Notes string `json:"notes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []AdminUserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
|
||||
@@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockInterceptStream(c, reqModel, interceptType)
|
||||
} else {
|
||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||
}
|
||||
return
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
@@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockInterceptStream(c, reqModel, interceptType)
|
||||
} else {
|
||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||
}
|
||||
return
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
@@ -765,17 +771,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
|
||||
func isWarmupRequest(body []byte) bool {
|
||||
// 快速检查:如果body不包含关键字,直接返回false
|
||||
// InterceptType 表示请求拦截类型
|
||||
type InterceptType int
|
||||
|
||||
const (
|
||||
InterceptTypeNone InterceptType = iota
|
||||
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||
)
|
||||
|
||||
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
||||
func detectInterceptType(body []byte) InterceptType {
|
||||
// 快速检查:如果不包含任何关键字,直接返回
|
||||
bodyStr := string(body)
|
||||
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
|
||||
return false
|
||||
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||||
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
|
||||
|
||||
if !hasSuggestionMode && !hasWarmupKeyword {
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// 解析完整请求
|
||||
// 解析请求(只解析一次)
|
||||
var req struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
@@ -786,43 +805,71 @@ func isWarmupRequest(body []byte) bool {
|
||||
} `json:"system"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return true
|
||||
// 检查 SUGGESTION MODE(最后一条 user 消息)
|
||||
if hasSuggestionMode && len(req.Messages) > 0 {
|
||||
lastMsg := req.Messages[len(req.Messages)-1]
|
||||
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
|
||||
lastMsg.Content[0].Type == "text" &&
|
||||
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
|
||||
return InterceptTypeSuggestionMode
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Warmup 请求
|
||||
if hasWarmupKeyword {
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return InterceptTypeWarmup
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, sys := range req.System {
|
||||
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return InterceptTypeWarmup
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, system := range req.System {
|
||||
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// 根据拦截类型决定响应内容
|
||||
var msgID string
|
||||
var outputTokens int
|
||||
var textDeltas []string
|
||||
|
||||
switch interceptType {
|
||||
case InterceptTypeSuggestionMode:
|
||||
msgID = "msg_mock_suggestion"
|
||||
outputTokens = 1
|
||||
textDeltas = []string{""} // 空内容
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
outputTokens = 2
|
||||
textDeltas = []string{"New", " Conversation"}
|
||||
}
|
||||
|
||||
// Build message_start event with proper JSON marshaling
|
||||
messageStart := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": "msg_mock_warmup",
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
@@ -837,16 +884,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
}
|
||||
messageStartJSON, _ := json.Marshal(messageStart)
|
||||
|
||||
// Build events
|
||||
events := []string{
|
||||
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
||||
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
|
||||
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
|
||||
}
|
||||
|
||||
// Add text deltas
|
||||
for _, text := range textDeltas {
|
||||
delta := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{
|
||||
"type": "text_delta",
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
deltaJSON, _ := json.Marshal(delta)
|
||||
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||||
}
|
||||
|
||||
// Add final events
|
||||
messageDelta := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
}
|
||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||
|
||||
events = append(events,
|
||||
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
|
||||
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
|
||||
)
|
||||
|
||||
for _, event := range events {
|
||||
_, _ = c.Writer.WriteString(event + "\n\n")
|
||||
c.Writer.Flush()
|
||||
@@ -854,18 +931,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
}
|
||||
}
|
||||
|
||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||||
var msgID, text string
|
||||
var outputTokens int
|
||||
|
||||
switch interceptType {
|
||||
case InterceptTypeSuggestionMode:
|
||||
msgID = "msg_mock_suggestion"
|
||||
text = ""
|
||||
outputTokens = 1
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
text = "New Conversation"
|
||||
outputTokens = 2
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": "msg_mock_warmup",
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"content": []gin.H{{"type": "text", "text": text}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 2,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractGeminiCLISessionHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
privilegedUserID string
|
||||
wantEmpty bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "with privileged-user-id and tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: false,
|
||||
wantHash: func() string {
|
||||
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "without privileged-user-id but with tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "",
|
||||
wantEmpty: false,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "without tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 创建测试上下文
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/test", nil)
|
||||
if tt.privilegedUserID != "" {
|
||||
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
|
||||
}
|
||||
|
||||
// 调用函数
|
||||
result := extractGeminiCLISessionHash(c, []byte(tt.body))
|
||||
|
||||
// 验证结果
|
||||
if tt.wantEmpty {
|
||||
require.Empty(t, result, "expected empty session hash")
|
||||
} else {
|
||||
require.NotEmpty(t, result, "expected non-empty session hash")
|
||||
require.Equal(t, tt.wantHash, result, "session hash mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiCLITmpDirRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantMatch bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "valid tmp dir path",
|
||||
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "valid tmp dir path in text",
|
||||
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "invalid hash length",
|
||||
input: "/Users/ianshaw/.gemini/tmp/abc123",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "no tmp dir",
|
||||
input: "Hello world",
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
|
||||
if tt.wantMatch {
|
||||
require.NotNil(t, match, "expected regex to match")
|
||||
require.Len(t, match, 2, "expected 2 capture groups")
|
||||
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
|
||||
} else {
|
||||
require.Nil(t, match, "expected regex not to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +23,17 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||
|
||||
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
||||
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
||||
return true
|
||||
}
|
||||
return geminiCLITmpDirRegex.Match(body)
|
||||
}
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
// 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
|
||||
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||
if sessionHash == "" {
|
||||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
|
||||
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
}
|
||||
isCLI := isGeminiCLIRequest(c, body)
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionBoundAccountID == 0 {
|
||||
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
|
||||
sessionBoundAccountID = account.ID
|
||||
}
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||
//
|
||||
// 会话标识生成策略:
|
||||
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
|
||||
// 2. 从 header 中提取 privileged-user-id(UUID)
|
||||
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
|
||||
//
|
||||
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
|
||||
//
|
||||
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
|
||||
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
|
||||
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
|
||||
// 1. 从请求体中提取 tmp 目录哈希
|
||||
match := geminiCLITmpDirRegex.FindSubmatch(body)
|
||||
if len(match) < 2 {
|
||||
return "" // 没有找到 tmp 目录,不使用粘性会话
|
||||
}
|
||||
tmpDirHash := string(match[1])
|
||||
|
||||
// 2. 提取 privileged-user-id
|
||||
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
|
||||
|
||||
// 3. 组合生成最终的 session hash
|
||||
if privilegedUserID != "" {
|
||||
// 组合两个标识符:privileged-user-id + tmp 目录哈希
|
||||
combined := privilegedUserID + ":" + tmpDirHash
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
|
||||
return tmpDirHash
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ type Handlers struct {
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
@@ -32,18 +32,24 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
181
backend/internal/handler/totp_handler.go
Normal file
181
backend/internal/handler/totp_handler.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TotpHandler handles TOTP-related requests
|
||||
type TotpHandler struct {
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewTotpHandler creates a new TotpHandler
|
||||
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
|
||||
return &TotpHandler{
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
// TotpStatusResponse represents the TOTP status response
|
||||
type TotpStatusResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
|
||||
FeatureEnabled bool `json:"feature_enabled"`
|
||||
}
|
||||
|
||||
// GetStatus returns the TOTP status for the current user
|
||||
// GET /api/v1/user/totp/status
|
||||
func (h *TotpHandler) GetStatus(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := TotpStatusResponse{
|
||||
Enabled: status.Enabled,
|
||||
FeatureEnabled: status.FeatureEnabled,
|
||||
}
|
||||
|
||||
if status.EnabledAt != nil {
|
||||
ts := status.EnabledAt.Unix()
|
||||
resp.EnabledAt = &ts
|
||||
}
|
||||
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// TotpSetupRequest represents the request to initiate TOTP setup
|
||||
type TotpSetupRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// TotpSetupResponse represents the TOTP setup response
|
||||
type TotpSetupResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeURL string `json:"qr_code_url"`
|
||||
SetupToken string `json:"setup_token"`
|
||||
Countdown int `json:"countdown"`
|
||||
}
|
||||
|
||||
// InitiateSetup starts the TOTP setup process
|
||||
// POST /api/v1/user/totp/setup
|
||||
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpSetupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body (optional params)
|
||||
req = TotpSetupRequest{}
|
||||
}
|
||||
|
||||
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpSetupResponse{
|
||||
Secret: result.Secret,
|
||||
QRCodeURL: result.QRCodeURL,
|
||||
SetupToken: result.SetupToken,
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
// TotpEnableRequest represents the request to enable TOTP
|
||||
type TotpEnableRequest struct {
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
SetupToken string `json:"setup_token" binding:"required"`
|
||||
}
|
||||
|
||||
// Enable completes the TOTP setup
|
||||
// POST /api/v1/user/totp/enable
|
||||
func (h *TotpHandler) Enable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpEnableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// TotpDisableRequest represents the request to disable TOTP
|
||||
type TotpDisableRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// Disable disables TOTP for the current user
|
||||
// POST /api/v1/user/totp/disable
|
||||
func (h *TotpHandler) Disable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpDisableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetVerificationMethod returns the verification method for TOTP operations
|
||||
// GET /api/v1/user/totp/verification-method
|
||||
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
|
||||
method := h.totpService.GetVerificationMethod(c.Request.Context())
|
||||
response.Success(c, method)
|
||||
}
|
||||
|
||||
// SendVerifyCode sends an email verification code for TOTP operations
|
||||
// POST /api/v1/user/totp/send-code
|
||||
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(userData))
|
||||
}
|
||||
|
||||
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
}
|
||||
|
||||
@@ -70,6 +70,7 @@ func ProvideHandlers(
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
@@ -82,6 +83,7 @@ func ProvideHandlers(
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
|
||||
@@ -7,13 +7,11 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
}
|
||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
||||
if block.Signature != "" {
|
||||
// signature 处理:
|
||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if !allowDummyThought {
|
||||
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
||||
@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
},
|
||||
}
|
||||
// tool_use 的 signature 处理:
|
||||
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
|
||||
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
|
||||
if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
|
||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
}
|
||||
|
||||
// 清理 JSON Schema
|
||||
params := cleanJSONSchema(inputSchema)
|
||||
// 1. 深度清理 [undefined] 值
|
||||
DeepCleanUndefined(inputSchema)
|
||||
// 2. 转换为符合 Gemini v1internal 的 schema
|
||||
params := CleanJSONSchema(inputSchema)
|
||||
// 为 nil schema 提供默认值
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "OBJECT",
|
||||
"type": "object", // lowercase type
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
FunctionDeclarations: funcDecls,
|
||||
}}
|
||||
}
|
||||
|
||||
// cleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
|
||||
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
|
||||
func cleanJSONSchema(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
cleaned := cleanSchemaValue(schema, "$")
|
||||
result, ok := cleaned.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 确保有 type 字段(默认 OBJECT)
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "OBJECT"
|
||||
}
|
||||
|
||||
// 确保有 properties 字段(默认空对象)
|
||||
if _, hasProps := result["properties"]; !hasProps {
|
||||
result["properties"] = make(map[string]any)
|
||||
}
|
||||
|
||||
// 验证 required 中的字段都存在于 properties 中
|
||||
if required, ok := result["required"].([]any); ok {
|
||||
if props, ok := result["properties"].(map[string]any); ok {
|
||||
validRequired := make([]any, 0, len(required))
|
||||
for _, r := range required {
|
||||
if reqName, ok := r.(string); ok {
|
||||
if _, exists := props[reqName]; exists {
|
||||
validRequired = append(validRequired, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(validRequired) > 0 {
|
||||
result["required"] = validRequired
|
||||
} else {
|
||||
delete(result, "required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
var schemaValidationKeys = map[string]bool{
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"pattern": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"exclusiveMinimum": true,
|
||||
"exclusiveMaximum": true,
|
||||
"multipleOf": true,
|
||||
"uniqueItems": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
"patternProperties": true,
|
||||
"propertyNames": true,
|
||||
"dependencies": true,
|
||||
"dependentSchemas": true,
|
||||
"dependentRequired": true,
|
||||
}
|
||||
|
||||
var warnedSchemaKeys sync.Map
|
||||
|
||||
func schemaCleaningWarningsEnabled() bool {
|
||||
// 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
|
||||
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
|
||||
switch strings.ToLower(v) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
}
|
||||
}
|
||||
// 默认:非 release 模式下输出(debug/test)
|
||||
return gin.Mode() != gin.ReleaseMode
|
||||
}
|
||||
|
||||
func warnSchemaKeyRemovedOnce(key, path string) {
|
||||
if !schemaCleaningWarningsEnabled() {
|
||||
return
|
||||
}
|
||||
if !schemaValidationKeys[key] {
|
||||
return
|
||||
}
|
||||
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
|
||||
return
|
||||
}
|
||||
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
|
||||
}
|
||||
|
||||
// excludedSchemaKeys 不支持的 schema 字段
|
||||
// 基于 Claude API (Vertex AI) 的实际支持情况
|
||||
// 支持: type, description, enum, properties, required, additionalProperties, items
|
||||
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
|
||||
var excludedSchemaKeys = map[string]bool{
|
||||
// 元 schema 字段
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
|
||||
// 字符串验证(Gemini 不支持)
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"pattern": true,
|
||||
|
||||
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"exclusiveMinimum": true,
|
||||
"exclusiveMaximum": true,
|
||||
"multipleOf": true,
|
||||
|
||||
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||
"uniqueItems": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
|
||||
// 组合 schema(Gemini 不支持)
|
||||
"oneOf": true,
|
||||
"anyOf": true,
|
||||
"allOf": true,
|
||||
"not": true,
|
||||
"if": true,
|
||||
"then": true,
|
||||
"else": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
|
||||
// 对象验证(仅保留 properties/required/additionalProperties)
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
"patternProperties": true,
|
||||
"propertyNames": true,
|
||||
"dependencies": true,
|
||||
"dependentSchemas": true,
|
||||
"dependentRequired": true,
|
||||
|
||||
// 其他不支持的字段
|
||||
"default": true,
|
||||
"const": true,
|
||||
"examples": true,
|
||||
"deprecated": true,
|
||||
"readOnly": true,
|
||||
"writeOnly": true,
|
||||
"contentMediaType": true,
|
||||
"contentEncoding": true,
|
||||
|
||||
// Claude 特有字段
|
||||
"strict": true,
|
||||
}
|
||||
|
||||
// cleanSchemaValue 递归清理 schema 值
|
||||
func cleanSchemaValue(value any, path string) any {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
result := make(map[string]any)
|
||||
for k, val := range v {
|
||||
// 跳过不支持的字段
|
||||
if excludedSchemaKeys[k] {
|
||||
warnSchemaKeyRemovedOnce(k, path)
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 type 字段
|
||||
if k == "type" {
|
||||
result[k] = cleanTypeValue(val)
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
|
||||
if k == "format" {
|
||||
if formatStr, ok := val.(string); ok {
|
||||
// Gemini 只支持 date-time, date, time
|
||||
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
|
||||
result[k] = val
|
||||
}
|
||||
// 其他 format 值直接跳过
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
|
||||
if k == "additionalProperties" {
|
||||
if boolVal, ok := val.(bool); ok {
|
||||
result[k] = boolVal
|
||||
} else {
|
||||
// 如果是 schema 对象,转换为 false(更安全的默认值)
|
||||
result[k] = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 递归清理所有值
|
||||
result[k] = cleanSchemaValue(val, path+"."+k)
|
||||
}
|
||||
return result
|
||||
|
||||
case []any:
|
||||
// 递归处理数组中的每个元素
|
||||
cleaned := make([]any, 0, len(v))
|
||||
for i, item := range v {
|
||||
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
|
||||
}
|
||||
return cleaned
|
||||
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// cleanTypeValue 处理 type 字段,转换为大写
|
||||
func cleanTypeValue(value any) any {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return strings.ToUpper(v)
|
||||
case []any:
|
||||
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
|
||||
for _, t := range v {
|
||||
if ts, ok := t.(string); ok && ts != "null" {
|
||||
return strings.ToUpper(ts)
|
||||
}
|
||||
}
|
||||
// 如果只有 null,返回 STRING
|
||||
return "STRING"
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
|
||||
]`
|
||||
|
||||
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
|
||||
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
|
||||
if err != nil {
|
||||
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != "sig_tool_abc" {
|
||||
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
|
||||
contentNoSig := `[
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
|
||||
]`
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
||||
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package antigravity
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
} else if len(v1Resp.Response.Candidates) == 0 {
|
||||
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
|
||||
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
// 使用处理器转换
|
||||
@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.textBuilder += part.Text
|
||||
|
||||
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
|
||||
// 非空 text 带签名 - 特殊处理:先输出 text,再输出空 thinking 块
|
||||
if signature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
})
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: signature,
|
||||
})
|
||||
} else {
|
||||
// 普通 text (无签名) - 累积到 builder
|
||||
p.textBuilder += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
||||
var finishReason string
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason = geminiResp.Candidates[0].FinishReason
|
||||
if finishReason == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
|
||||
if geminiResp.Candidates[0].Content != nil {
|
||||
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stopReason := "end_turn"
|
||||
|
||||
519
backend/internal/pkg/antigravity/schema_cleaner.go
Normal file
519
backend/internal/pkg/antigravity/schema_cleaner.go
Normal file
@@ -0,0 +1,519 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
|
||||
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
|
||||
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
|
||||
func CleanJSONSchema(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
// 0. 预处理:展开 $ref (Schema Flattening)
|
||||
// (Go map 是引用的,直接修改 schema)
|
||||
flattenRefs(schema, extractDefs(schema))
|
||||
|
||||
// 递归清理
|
||||
cleaned := cleanJSONSchemaRecursive(schema)
|
||||
result, ok := cleaned.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractDefs 提取并移除定义的 helper
|
||||
func extractDefs(schema map[string]any) map[string]any {
|
||||
defs := make(map[string]any)
|
||||
if d, ok := schema["$defs"].(map[string]any); ok {
|
||||
for k, v := range d {
|
||||
defs[k] = v
|
||||
}
|
||||
delete(schema, "$defs")
|
||||
}
|
||||
if d, ok := schema["definitions"].(map[string]any); ok {
|
||||
for k, v := range d {
|
||||
defs[k] = v
|
||||
}
|
||||
delete(schema, "definitions")
|
||||
}
|
||||
return defs
|
||||
}
|
||||
|
||||
// flattenRefs 递归展开 $ref
|
||||
func flattenRefs(schema map[string]any, defs map[string]any) {
|
||||
if len(defs) == 0 {
|
||||
return // 无需展开
|
||||
}
|
||||
|
||||
// 检查并替换 $ref
|
||||
if ref, ok := schema["$ref"].(string); ok {
|
||||
delete(schema, "$ref")
|
||||
// 解析引用名 (例如 #/$defs/MyType -> MyType)
|
||||
parts := strings.Split(ref, "/")
|
||||
refName := parts[len(parts)-1]
|
||||
|
||||
if defSchema, exists := defs[refName]; exists {
|
||||
if defMap, ok := defSchema.(map[string]any); ok {
|
||||
// 合并定义内容 (不覆盖现有 key)
|
||||
for k, v := range defMap {
|
||||
if _, has := schema[k]; !has {
|
||||
schema[k] = deepCopy(v) // 需深拷贝避免共享引用
|
||||
}
|
||||
}
|
||||
// 递归处理刚刚合并进来的内容
|
||||
flattenRefs(schema, defs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 遍历子节点
|
||||
for _, v := range schema {
|
||||
if subMap, ok := v.(map[string]any); ok {
|
||||
flattenRefs(subMap, defs)
|
||||
} else if subArr, ok := v.([]any); ok {
|
||||
for _, item := range subArr {
|
||||
if itemMap, ok := item.(map[string]any); ok {
|
||||
flattenRefs(itemMap, defs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
|
||||
func deepCopy(src any) any {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case map[string]any:
|
||||
dst := make(map[string]any)
|
||||
for k, val := range v {
|
||||
dst[k] = deepCopy(val)
|
||||
}
|
||||
return dst
|
||||
case []any:
|
||||
dst := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
dst[i] = deepCopy(val)
|
||||
}
|
||||
return dst
|
||||
default:
|
||||
return src
|
||||
}
|
||||
}
|
||||
|
||||
// cleanJSONSchemaRecursive 递归核心清理逻辑
|
||||
// 返回处理后的值 (通常是 input map,但可能修改内部结构)
|
||||
func cleanJSONSchemaRecursive(value any) any {
|
||||
schemaMap, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return value
|
||||
}
|
||||
|
||||
// 0. [NEW] 合并 allOf
|
||||
mergeAllOf(schemaMap)
|
||||
|
||||
// 1. [CRITICAL] 深度递归处理子项
|
||||
if props, ok := schemaMap["properties"].(map[string]any); ok {
|
||||
for _, v := range props {
|
||||
cleanJSONSchemaRecursive(v)
|
||||
}
|
||||
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required,
|
||||
// 因为我们在子项处理中会正确设置 type 和 description
|
||||
} else if items, ok := schemaMap["items"]; ok {
|
||||
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
|
||||
if itemsArr, ok := items.([]any); ok {
|
||||
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
|
||||
best := extractBestSchemaFromUnion(itemsArr)
|
||||
if best == nil {
|
||||
// 回退到通用字符串
|
||||
best = map[string]any{"type": "string"}
|
||||
}
|
||||
// 用处理后的对象替换原有数组
|
||||
cleanedBest := cleanJSONSchemaRecursive(best)
|
||||
schemaMap["items"] = cleanedBest
|
||||
} else {
|
||||
cleanJSONSchemaRecursive(items)
|
||||
}
|
||||
} else {
|
||||
// 遍历所有值递归
|
||||
for _, v := range schemaMap {
|
||||
if _, isMap := v.(map[string]any); isMap {
|
||||
cleanJSONSchemaRecursive(v)
|
||||
} else if arr, isArr := v.([]any); isArr {
|
||||
for _, item := range arr {
|
||||
cleanJSONSchemaRecursive(item)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
|
||||
var unionArray []any
|
||||
typeStr, _ := schemaMap["type"].(string)
|
||||
if typeStr == "" || typeStr == "object" {
|
||||
if anyOf, ok := schemaMap["anyOf"].([]any); ok {
|
||||
unionArray = anyOf
|
||||
} else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
|
||||
unionArray = oneOf
|
||||
}
|
||||
}
|
||||
|
||||
if len(unionArray) > 0 {
|
||||
if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
|
||||
if bestMap, ok := bestBranch.(map[string]any); ok {
|
||||
// 合并分支内容
|
||||
for k, v := range bestMap {
|
||||
if k == "properties" {
|
||||
targetProps, _ := schemaMap["properties"].(map[string]any)
|
||||
if targetProps == nil {
|
||||
targetProps = make(map[string]any)
|
||||
schemaMap["properties"] = targetProps
|
||||
}
|
||||
if sourceProps, ok := v.(map[string]any); ok {
|
||||
for pk, pv := range sourceProps {
|
||||
if _, exists := targetProps[pk]; !exists {
|
||||
targetProps[pk] = deepCopy(pv)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if k == "required" {
|
||||
targetReq, _ := schemaMap["required"].([]any)
|
||||
if sourceReq, ok := v.([]any); ok {
|
||||
for _, rv := range sourceReq {
|
||||
// 简单的去重添加
|
||||
exists := false
|
||||
for _, tr := range targetReq {
|
||||
if tr == rv {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
targetReq = append(targetReq, rv)
|
||||
}
|
||||
}
|
||||
schemaMap["required"] = targetReq
|
||||
}
|
||||
} else if _, exists := schemaMap[k]; !exists {
|
||||
schemaMap[k] = deepCopy(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
|
||||
looksLikeSchema := hasKey(schemaMap, "type") ||
|
||||
hasKey(schemaMap, "properties") ||
|
||||
hasKey(schemaMap, "items") ||
|
||||
hasKey(schemaMap, "enum") ||
|
||||
hasKey(schemaMap, "anyOf") ||
|
||||
hasKey(schemaMap, "oneOf") ||
|
||||
hasKey(schemaMap, "allOf")
|
||||
|
||||
if looksLikeSchema {
|
||||
// 4. [ROBUST] 约束迁移
|
||||
migrateConstraints(schemaMap)
|
||||
|
||||
// 5. [CRITICAL] 白名单过滤
|
||||
allowedFields := map[string]bool{
|
||||
"type": true,
|
||||
"description": true,
|
||||
"properties": true,
|
||||
"required": true,
|
||||
"items": true,
|
||||
"enum": true,
|
||||
"title": true,
|
||||
}
|
||||
for k := range schemaMap {
|
||||
if !allowedFields[k] {
|
||||
delete(schemaMap, k)
|
||||
}
|
||||
}
|
||||
|
||||
// 6. [SAFETY] 处理空 Object
|
||||
if t, _ := schemaMap["type"].(string); t == "object" {
|
||||
hasProps := false
|
||||
if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
|
||||
hasProps = true
|
||||
}
|
||||
if !hasProps {
|
||||
schemaMap["properties"] = map[string]any{
|
||||
"reason": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Reason for calling this tool",
|
||||
},
|
||||
}
|
||||
schemaMap["required"] = []any{"reason"}
|
||||
}
|
||||
}
|
||||
|
||||
// 7. [SAFETY] Required 字段对齐
|
||||
if props, ok := schemaMap["properties"].(map[string]any); ok {
|
||||
if req, ok := schemaMap["required"].([]any); ok {
|
||||
var validReq []any
|
||||
for _, r := range req {
|
||||
if rStr, ok := r.(string); ok {
|
||||
if _, exists := props[rStr]; exists {
|
||||
validReq = append(validReq, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(validReq) > 0 {
|
||||
schemaMap["required"] = validReq
|
||||
} else {
|
||||
delete(schemaMap, "required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
|
||||
isEffectivelyNullable := false
|
||||
if typeVal, exists := schemaMap["type"]; exists {
|
||||
var selectedType string
|
||||
switch v := typeVal.(type) {
|
||||
case string:
|
||||
lower := strings.ToLower(v)
|
||||
if lower == "null" {
|
||||
isEffectivelyNullable = true
|
||||
selectedType = "string" // fallback
|
||||
} else {
|
||||
selectedType = lower
|
||||
}
|
||||
case []any:
|
||||
// ["string", "null"]
|
||||
for _, t := range v {
|
||||
if ts, ok := t.(string); ok {
|
||||
lower := strings.ToLower(ts)
|
||||
if lower == "null" {
|
||||
isEffectivelyNullable = true
|
||||
} else if selectedType == "" {
|
||||
selectedType = lower
|
||||
}
|
||||
}
|
||||
}
|
||||
if selectedType == "" {
|
||||
selectedType = "string"
|
||||
}
|
||||
}
|
||||
schemaMap["type"] = selectedType
|
||||
} else {
|
||||
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
|
||||
// 如果没有 type,但有 properties,补一个
|
||||
if hasKey(schemaMap, "properties") {
|
||||
schemaMap["type"] = "object"
|
||||
} else {
|
||||
// 默认为 string ? or object? Gemini 通常需要明确 type
|
||||
schemaMap["type"] = "object"
|
||||
}
|
||||
}
|
||||
|
||||
if isEffectivelyNullable {
|
||||
desc, _ := schemaMap["description"].(string)
|
||||
if !strings.Contains(desc, "nullable") {
|
||||
if desc != "" {
|
||||
desc += " "
|
||||
}
|
||||
desc += "(nullable)"
|
||||
schemaMap["description"] = desc
|
||||
}
|
||||
}
|
||||
|
||||
// 9. Enum 值强制转字符串
|
||||
if enumVals, ok := schemaMap["enum"].([]any); ok {
|
||||
hasNonString := false
|
||||
for i, val := range enumVals {
|
||||
if _, isStr := val.(string); !isStr {
|
||||
hasNonString = true
|
||||
if val == nil {
|
||||
enumVals[i] = "null"
|
||||
} else {
|
||||
enumVals[i] = fmt.Sprintf("%v", val)
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we mandated string values, we must ensure type is string
|
||||
if hasNonString {
|
||||
schemaMap["type"] = "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return schemaMap
|
||||
}
|
||||
|
||||
func hasKey(m map[string]any, k string) bool {
|
||||
_, ok := m[k]
|
||||
return ok
|
||||
}
|
||||
|
||||
func migrateConstraints(m map[string]any) {
|
||||
constraints := []struct {
|
||||
key string
|
||||
label string
|
||||
}{
|
||||
{"minLength", "minLen"},
|
||||
{"maxLength", "maxLen"},
|
||||
{"pattern", "pattern"},
|
||||
{"minimum", "min"},
|
||||
{"maximum", "max"},
|
||||
{"multipleOf", "multipleOf"},
|
||||
{"exclusiveMinimum", "exclMin"},
|
||||
{"exclusiveMaximum", "exclMax"},
|
||||
{"minItems", "minItems"},
|
||||
{"maxItems", "maxItems"},
|
||||
{"propertyNames", "propertyNames"},
|
||||
{"format", "format"},
|
||||
}
|
||||
|
||||
var hints []string
|
||||
for _, c := range constraints {
|
||||
if val, ok := m[c.key]; ok && val != nil {
|
||||
hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
|
||||
}
|
||||
}
|
||||
|
||||
if len(hints) > 0 {
|
||||
suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
|
||||
desc, _ := m["description"].(string)
|
||||
if !strings.Contains(desc, suffix) {
|
||||
m["description"] = desc + suffix
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mergeAllOf 合并 allOf
|
||||
func mergeAllOf(m map[string]any) {
|
||||
allOf, ok := m["allOf"].([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(m, "allOf")
|
||||
|
||||
mergedProps := make(map[string]any)
|
||||
mergedReq := make(map[string]bool)
|
||||
otherFields := make(map[string]any)
|
||||
|
||||
for _, sub := range allOf {
|
||||
if subMap, ok := sub.(map[string]any); ok {
|
||||
// Props
|
||||
if props, ok := subMap["properties"].(map[string]any); ok {
|
||||
for k, v := range props {
|
||||
mergedProps[k] = v
|
||||
}
|
||||
}
|
||||
// Required
|
||||
if reqs, ok := subMap["required"].([]any); ok {
|
||||
for _, r := range reqs {
|
||||
if s, ok := r.(string); ok {
|
||||
mergedReq[s] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Others
|
||||
for k, v := range subMap {
|
||||
if k != "properties" && k != "required" && k != "allOf" {
|
||||
if _, exists := otherFields[k]; !exists {
|
||||
otherFields[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply
|
||||
for k, v := range otherFields {
|
||||
if _, exists := m[k]; !exists {
|
||||
m[k] = v
|
||||
}
|
||||
}
|
||||
if len(mergedProps) > 0 {
|
||||
existProps, _ := m["properties"].(map[string]any)
|
||||
if existProps == nil {
|
||||
existProps = make(map[string]any)
|
||||
m["properties"] = existProps
|
||||
}
|
||||
for k, v := range mergedProps {
|
||||
if _, exists := existProps[k]; !exists {
|
||||
existProps[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(mergedReq) > 0 {
|
||||
existReq, _ := m["required"].([]any)
|
||||
var validReqs []any
|
||||
for _, r := range existReq {
|
||||
if s, ok := r.(string); ok {
|
||||
validReqs = append(validReqs, s)
|
||||
delete(mergedReq, s) // already exists
|
||||
}
|
||||
}
|
||||
// append new
|
||||
for r := range mergedReq {
|
||||
validReqs = append(validReqs, r)
|
||||
}
|
||||
m["required"] = validReqs
|
||||
}
|
||||
}
|
||||
|
||||
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
|
||||
func extractBestSchemaFromUnion(unionArray []any) any {
|
||||
var bestOption any
|
||||
bestScore := -1
|
||||
|
||||
for _, item := range unionArray {
|
||||
score := scoreSchemaOption(item)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestOption = item
|
||||
}
|
||||
}
|
||||
return bestOption
|
||||
}
|
||||
|
||||
func scoreSchemaOption(val any) int {
|
||||
m, ok := val.(map[string]any)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
typeStr, _ := m["type"].(string)
|
||||
|
||||
if hasKey(m, "properties") || typeStr == "object" {
|
||||
return 3
|
||||
}
|
||||
if hasKey(m, "items") || typeStr == "array" {
|
||||
return 2
|
||||
}
|
||||
if typeStr != "" && typeStr != "null" {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
|
||||
func DeepCleanUndefined(value any) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
for k, val := range v {
|
||||
if s, ok := val.(string); ok && s == "[undefined]" {
|
||||
delete(v, k)
|
||||
continue
|
||||
}
|
||||
DeepCleanUndefined(val)
|
||||
}
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
DeepCleanUndefined(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
// 检查是否结束
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason := geminiResp.Candidates[0].FinishReason
|
||||
if finishReason == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
|
||||
if geminiResp.Candidates[0].Content != nil {
|
||||
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
if finishReason != "" {
|
||||
_, _ = result.Write(p.emitFinish(finishReason))
|
||||
}
|
||||
|
||||
@@ -13,20 +13,26 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claude OAuth Constants (from CRS project)
|
||||
// Claude OAuth Constants
|
||||
const (
|
||||
// OAuth Client ID for Claude
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
|
||||
TokenURL = "https://platform.claude.com/v1/oauth/token"
|
||||
RedirectURI = "https://platform.claude.com/oauth/code/callback"
|
||||
|
||||
// Scopes
|
||||
ScopeProfile = "user:profile"
|
||||
// Scopes - Browser URL (includes org:create_api_key for user authorization)
|
||||
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
// Scopes - Internal API call (org:create_api_key not supported in API)
|
||||
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
// Scopes - Setup token (inference only)
|
||||
ScopeInference = "user:inference"
|
||||
|
||||
// Code Verifier character set (RFC 7636 compliant)
|
||||
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
// GenerateState generates a random state string for OAuth (base64url encoded)
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
|
||||
// GenerateCodeVerifier generates a PKCE code verifier using character set method
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
const targetLen = 32
|
||||
charsetLen := len(codeVerifierCharset)
|
||||
limit := 256 - (256 % charsetLen)
|
||||
|
||||
result := make([]byte, 0, targetLen)
|
||||
randBuf := make([]byte, targetLen*2)
|
||||
|
||||
for len(result) < targetLen {
|
||||
if _, err := rand.Read(randBuf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, b := range randBuf {
|
||||
if int(b) < limit {
|
||||
result = append(result, codeVerifierCharset[int(b)%charsetLen])
|
||||
if len(result) >= targetLen {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
|
||||
return base64URLEncode(result), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
|
||||
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("scope", scope)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
encodedRedirectURI := url.QueryEscape(RedirectURI)
|
||||
encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
State string `json:"state"`
|
||||
return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
|
||||
AuthorizeURL,
|
||||
ClientID,
|
||||
encodedRedirectURI,
|
||||
encodedScope,
|
||||
codeChallenge,
|
||||
state,
|
||||
)
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OAuth provider
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// Organization and Account info from OAuth response
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Organization *OrgInfo `json:"organization,omitempty"`
|
||||
Account *AccountInfo `json:"account,omitempty"`
|
||||
}
|
||||
@@ -205,33 +215,6 @@ type OrgInfo struct {
|
||||
|
||||
// AccountInfo represents account info from OAuth response
|
||||
type AccountInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request
|
||||
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: RedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
State: state,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
}
|
||||
UUID string `json:"uuid"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
|
||||
}
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
|
||||
// Log internal errors with full details for debugging
|
||||
if statusCode >= 500 && c.Request != nil {
|
||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
|
||||
}
|
||||
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
return true
|
||||
}
|
||||
|
||||
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
//go:build integration
|
||||
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Integration tests for verifying TLS fingerprint correctness.
|
||||
// These tests make actual network requests to external services and should be run manually.
|
||||
//
|
||||
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// skipIfExternalServiceUnavailable checks if the external service is available.
|
||||
// If not, it skips the test instead of failing.
|
||||
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
// Check for common network/TLS errors that indicate external service issues
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "certificate has expired") ||
|
||||
strings.Contains(errStr, "certificate is not yet valid") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "no such host") ||
|
||||
strings.Contains(errStr, "network is unreachable") ||
|
||||
strings.Contains(errStr, "timeout") {
|
||||
t.Skipf("skipping test: external service unavailable: %v", err)
|
||||
}
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||
// This test uses tls.peet.ws to verify the fingerprint.
|
||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||
func TestJA3Fingerprint(t *testing.T) {
|
||||
// Skip if network is unavailable or if running in short mode
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
profile := &Profile{
|
||||
Name: "Claude CLI Test",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Use tls.peet.ws fingerprint detection API
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
skipIfExternalServiceUnavailable(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
}
|
||||
|
||||
// Log all fingerprint information
|
||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||
|
||||
// Verify JA3 hash matches expected value
|
||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||
} else {
|
||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||
}
|
||||
|
||||
// Verify JA4 fingerprint
|
||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||
}
|
||||
|
||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||
// d = domain (SNI present), i = IP (no SNI)
|
||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||
expectedJA4Prefix := "t13d5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||
} else {
|
||||
// Also accept 'i' variant for IP connections
|
||||
altPrefix := "t13i5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||
} else {
|
||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||
}
|
||||
|
||||
// Verify extension list (should be 11 extensions including SNI)
|
||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||
} else {
|
||||
t.Logf("Warning: JA3 extension list may differ")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileExpectation defines expected fingerprint values for a profile.
|
||||
type TestProfileExpectation struct {
|
||||
Profile *Profile
|
||||
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
|
||||
ExpectedJA4 string // Expected full JA4 (empty = don't check)
|
||||
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
|
||||
}
|
||||
|
||||
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
|
||||
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
|
||||
func TestAllProfiles(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Define all profiles to test with their expected fingerprints
|
||||
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
|
||||
profiles := []TestProfileExpectation{
|
||||
{
|
||||
// Linux x64 Node.js v22.17.1
|
||||
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
|
||||
Profile: &Profile{
|
||||
Name: "linux_x64_node_v22171",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part
|
||||
},
|
||||
{
|
||||
// MacOS arm64 Node.js v22.18.0
|
||||
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
|
||||
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
|
||||
Profile: &Profile{
|
||||
Name: "macos_arm64_node_v22180",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range profiles {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.Profile.Name, func(t *testing.T) {
|
||||
fp := fetchFingerprint(t, tc.Profile)
|
||||
if fp == nil {
|
||||
return // fetchFingerprint already called t.Fatal
|
||||
}
|
||||
|
||||
t.Logf("Profile: %s", tc.Profile.Name)
|
||||
t.Logf(" JA3: %s", fp.JA3)
|
||||
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
|
||||
t.Logf(" JA4: %s", fp.JA4)
|
||||
t.Logf(" PeetPrint: %s", fp.PeetPrint)
|
||||
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
|
||||
|
||||
// Verify expectations
|
||||
if tc.ExpectedJA3 != "" {
|
||||
if fp.JA3Hash == tc.ExpectedJA3 {
|
||||
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.ExpectedJA4 != "" {
|
||||
if fp.JA4 == tc.ExpectedJA4 {
|
||||
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
|
||||
}
|
||||
}
|
||||
|
||||
// Check JA4 cipher hash (stable middle part)
|
||||
// JA4 format: prefix_cipherHash_extHash
|
||||
if tc.JA4CipherHash != "" {
|
||||
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
|
||||
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
|
||||
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
|
||||
t.Helper()
|
||||
|
||||
dialer := NewDialer(profile, nil)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
skipIfExternalServiceUnavailable(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &fpResp.TLS
|
||||
}
|
||||
@@ -1,21 +1,16 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Integration tests for verifying TLS fingerprint correctness.
|
||||
// These tests make actual network requests and should be run manually.
|
||||
// Unit tests for TLS fingerprint dialer.
|
||||
// Integration tests that require external network are in dialer_integration_test.go
|
||||
// and require the 'integration' build tag.
|
||||
//
|
||||
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
|
||||
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
|
||||
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
|
||||
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||
@@ -36,148 +31,6 @@ type TLSInfo struct {
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
|
||||
func TestDialerBasicConnection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping network test in short mode")
|
||||
}
|
||||
|
||||
// Create a dialer with default profile
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
// Create HTTP client with custom TLS dialer
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Make a request to a known HTTPS endpoint
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||
// This test uses tls.peet.ws to verify the fingerprint.
|
||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||
func TestJA3Fingerprint(t *testing.T) {
|
||||
// Skip if network is unavailable or if running in short mode
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
profile := &Profile{
|
||||
Name: "Claude CLI Test",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Use tls.peet.ws fingerprint detection API
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
}
|
||||
|
||||
// Log all fingerprint information
|
||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||
|
||||
// Verify JA3 hash matches expected value
|
||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||
} else {
|
||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||
}
|
||||
|
||||
// Verify JA4 fingerprint
|
||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||
}
|
||||
|
||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||
// d = domain (SNI present), i = IP (no SNI)
|
||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||
expectedJA4Prefix := "t13d5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||
} else {
|
||||
// Also accept 'i' variant for IP connections
|
||||
altPrefix := "t13i5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||
} else {
|
||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||
}
|
||||
|
||||
// Verify extension list (should be 11 extensions including SNI)
|
||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||
} else {
|
||||
t.Logf("Warning: JA3 extension list may differ")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialerWithProfile tests that different profiles produce different fingerprints.
|
||||
func TestDialerWithProfile(t *testing.T) {
|
||||
// Create two dialers with different profiles
|
||||
|
||||
@@ -39,9 +39,15 @@ import (
|
||||
// 设计说明:
|
||||
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
||||
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
||||
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
|
||||
type accountRepository struct {
|
||||
client *dbent.Client // Ent ORM 客户端
|
||||
sql sqlExecutor // 原生 SQL 执行接口
|
||||
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
|
||||
// 确保粘性会话能及时感知账号不可用状态。
|
||||
// Used to proactively sync account snapshot to cache when status changes,
|
||||
// ensuring sticky sessions can promptly detect unavailable accounts.
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
type tempUnschedSnapshot struct {
|
||||
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
|
||||
return newAccountRepositoryWithSQL(client, sqlDB)
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
|
||||
}
|
||||
|
||||
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
||||
// 这种设计便于单元测试时注入 mock 对象。
|
||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
|
||||
return &accountRepository{client: client, sql: sqlq}
|
||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
|
||||
return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
|
||||
}
|
||||
|
||||
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
||||
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
|
||||
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
|
||||
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
|
||||
//
|
||||
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
|
||||
// when account status changes. Called when account is set to error, disabled,
|
||||
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
|
||||
// logic can promptly detect the latest account state and avoid using unavailable accounts.
|
||||
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
|
||||
if r == nil || r.schedulerCache == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
account, err := r.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
if !schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
shouldSync := false
|
||||
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||
shouldSync = true
|
||||
}
|
||||
if updates.Schedulable != nil && !*updates.Schedulable {
|
||||
shouldSync = true
|
||||
}
|
||||
if shouldSync {
|
||||
for _, id := range ids {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
tx := testEntTx(s.T())
|
||||
client := tx.Client()
|
||||
repo := newAccountRepositoryWithSQL(client, tx)
|
||||
repo := newAccountRepositoryWithSQL(client, tx, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(client)
|
||||
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(got.Schedulable)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
|
||||
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
disabled := service.StatusDisabled
|
||||
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
|
||||
Status: &disabled,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(2), rows)
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 2)
|
||||
ids := map[int64]struct{}{}
|
||||
for _, acc := range cacheRecorder.setAccounts {
|
||||
ids[acc.ID] = struct{}{}
|
||||
}
|
||||
s.Require().Contains(ids, account1.ID)
|
||||
s.Require().Contains(ids, account2.ID)
|
||||
}
|
||||
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
95
backend/internal/repository/aes_encryptor.go
Normal file
95
backend/internal/repository/aes_encryptor.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// AESEncryptor implements SecretEncryptor using AES-256-GCM
|
||||
type AESEncryptor struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewAESEncryptor creates a new AES encryptor
|
||||
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
|
||||
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
|
||||
}
|
||||
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
|
||||
}
|
||||
|
||||
return &AESEncryptor{key: key}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM
|
||||
// Output format: base64(nonce + ciphertext + tag)
|
||||
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the plaintext
|
||||
// Seal appends the ciphertext and tag to the nonce
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode as base64
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-256-GCM
|
||||
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode base64: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
UUID string `json:"uuid"`
|
||||
Name string `json:"name"`
|
||||
RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
|
||||
}
|
||||
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
// 如果只有一个组织,直接使用
|
||||
if len(orgs) == 1 {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
|
||||
for _, org := range orgs {
|
||||
if org.RavenType != nil && *org.RavenType == "team" {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||
org.UUID, org.Name, *org.RavenType)
|
||||
return org.UUID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 team 类型的组织,使用第一个
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
@@ -182,7 +200,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json, text/plain, */*").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", "axios/1.8.4").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
@@ -205,8 +225,6 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
|
||||
// Anthropic OAuth API 期望 JSON 格式的请求体
|
||||
reqBody := map[string]any{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
@@ -217,7 +235,9 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json, text/plain, */*").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", "axios/1.8.4").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
@@ -171,7 +171,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
|
||||
@@ -14,37 +14,82 @@ import (
|
||||
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
// 默认 User-Agent,与用户抓包的请求一致
|
||||
const defaultUsageUserAgent = "claude-code/2.1.7"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
allowPrivateHosts bool
|
||||
httpUpstream service.HTTPUpstream
|
||||
}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
||||
// NewClaudeUsageFetcher 创建 Claude 用量获取服务
|
||||
// httpUpstream: 可选,如果提供则支持 TLS 指纹伪装
|
||||
func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{
|
||||
usageURL: defaultClaudeUsageURL,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// FetchUsage 简单版本,不支持 TLS 指纹(向后兼容)
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
|
||||
AccessToken: accessToken,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
// FetchUsageWithOptions 完整版本,支持 TLS 指纹和自定义 User-Agent
|
||||
func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("options is nil")
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
// 设置请求头(与抓包一致,但不设置 Accept-Encoding,让 Go 自动处理压缩)
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
// 设置 User-Agent(优先使用缓存的 Fingerprint,否则使用默认值)
|
||||
userAgent := defaultUsageUserAgent
|
||||
if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" {
|
||||
userAgent = opts.Fingerprint.UserAgent
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
var resp *http.Response
|
||||
|
||||
// 如果启用 TLS 指纹且有 HTTPUpstream,使用 DoWithTLS
|
||||
if opts.EnableTLSFingerprint && s.httpUpstream != nil {
|
||||
// accountConcurrency 传 0 使用默认连接池配置,usage 请求不需要特殊的并发设置
|
||||
resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 不启用 TLS 指纹,使用普通 HTTP 客户端
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: opts.ProxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -9,13 +9,27 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
const (
|
||||
verifyCodeKeyPrefix = "verify_code:"
|
||||
passwordResetKeyPrefix = "password_reset:"
|
||||
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
||||
)
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetKey generates the Redis key for password reset token.
|
||||
func passwordResetKey(email string) string {
|
||||
return passwordResetKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||
func passwordResetSentAtKey(email string) string {
|
||||
return passwordResetSentAtKeyPrefix + email
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset token methods
|
||||
|
||||
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
|
||||
key := passwordResetKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.PasswordResetTokenData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
|
||||
key := passwordResetKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
key := passwordResetKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset email cooldown methods
|
||||
|
||||
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
key := passwordResetSentAtKey(email)
|
||||
exists, err := c.rdb.Exists(ctx, key).Result()
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
key := passwordResetSentAtKey(email)
|
||||
return c.rdb.Set(ctx, key, "1", ttl).Err()
|
||||
}
|
||||
|
||||
@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
|
||||
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
|
||||
// 以便下次请求能够重新选择可用账号。
|
||||
//
|
||||
// DeleteSessionAccountID removes the sticky session binding for the given session.
|
||||
// Called when the bound account becomes unavailable (e.g., error status, disabled,
|
||||
// or unschedulable), allowing subsequent requests to select a new available account.
|
||||
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
|
||||
sessionID := "openai:s4"
|
||||
accountID := int64(102)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
groupID := int64(1)
|
||||
|
||||
@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
|
||||
@@ -2,10 +2,11 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/imroc/req/v3"
|
||||
@@ -38,16 +39,17 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
@@ -66,16 +68,17 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
@@ -84,6 +87,6 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
Timeout: 120 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
}
|
||||
|
||||
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
|
||||
client := NewOpenAIOAuthClient()
|
||||
svc, ok := client.(*openaiOAuthService)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.TokenURL, svc.tokenURL)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ type reqClientOptions struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||
Timeout time.Duration // 请求超时时间
|
||||
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
||||
ForceHTTP2 bool // 是否强制使用 HTTP/2
|
||||
}
|
||||
|
||||
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
||||
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
||||
}
|
||||
|
||||
client := req.C().SetTimeout(opts.Timeout)
|
||||
if opts.ForceHTTP2 {
|
||||
client = client.EnableForceHTTP2()
|
||||
}
|
||||
if opts.Impersonate {
|
||||
client = client.ImpersonateChrome()
|
||||
}
|
||||
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
||||
}
|
||||
|
||||
func buildReqClientKey(opts reqClientOptions) string {
|
||||
return fmt.Sprintf("%s|%s|%t",
|
||||
return fmt.Sprintf("%s|%s|%t|%t",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.Impersonate,
|
||||
opts.ForceHTTP2,
|
||||
)
|
||||
}
|
||||
|
||||
90
backend/internal/repository/req_client_pool_test.go
Normal file
90
backend/internal/repository/req_client_pool_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func forceHTTPVersion(t *testing.T, client *req.Client) string {
|
||||
t.Helper()
|
||||
transport := client.GetTransport()
|
||||
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
|
||||
require.True(t, field.IsValid(), "forceHttpVersion field not found")
|
||||
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
|
||||
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
base := reqClientOptions{
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Timeout: time.Second,
|
||||
}
|
||||
clientDefault := getSharedReqClient(base)
|
||||
|
||||
force := base
|
||||
force.ForceHTTP2 = true
|
||||
clientForce := getSharedReqClient(force)
|
||||
|
||||
require.NotSame(t, clientDefault, clientForce)
|
||||
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
first := getSharedReqClient(opts)
|
||||
second := getSharedReqClient(opts)
|
||||
require.Same(t, first, second)
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: " http://proxy.local:8080 ",
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
key := buildReqClientKey(opts)
|
||||
sharedReqClients.Store(key, "invalid")
|
||||
|
||||
client := getSharedReqClient(opts)
|
||||
|
||||
require.NotNil(t, client)
|
||||
loaded, ok := sharedReqClients.Load(key)
|
||||
require.True(t, ok)
|
||||
require.IsType(t, "invalid", loaded)
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: " http://proxy.local:8080 ",
|
||||
Timeout: 4 * time.Second,
|
||||
Impersonate: true,
|
||||
}
|
||||
client := getSharedReqClient(opts)
|
||||
|
||||
require.NotNil(t, client)
|
||||
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
|
||||
}
|
||||
|
||||
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createOpenAIReqClient("http://proxy.local:8080")
|
||||
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
|
||||
}
|
||||
|
||||
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createGeminiReqClient("http://proxy.local:8080")
|
||||
require.Equal(t, "", forceHTTPVersion(t, client))
|
||||
}
|
||||
@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
return nil, false, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return []*service.Account{}, true, nil
|
||||
// 空快照视为缓存未命中,触发数据库回退查询
|
||||
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(ids))
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
|
||||
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
|
||||
|
||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
|
||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
|
||||
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
|
||||
cache := NewSchedulerCache(rdb)
|
||||
|
||||
|
||||
@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return make(map[int64]int), nil
|
||||
}
|
||||
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
|
||||
|
||||
// 使用 pipeline 批量执行
|
||||
pipe := c.rdb.Pipeline()
|
||||
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
|
||||
|
||||
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
key := sessionLimitKey(accountID)
|
||||
// 使用各账号自己的 idleTimeout,如果没有则用默认值
|
||||
idleTimeout := c.defaultIdleTimeout
|
||||
if idleTimeouts != nil {
|
||||
if t, ok := idleTimeouts[accountID]; ok && t > 0 {
|
||||
idleTimeout = t
|
||||
}
|
||||
}
|
||||
idleTimeoutSeconds := int(idleTimeout.Seconds())
|
||||
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
|
||||
}
|
||||
|
||||
|
||||
149
backend/internal/repository/totp_cache.go
Normal file
149
backend/internal/repository/totp_cache.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
totpSetupKeyPrefix = "totp:setup:"
|
||||
totpLoginKeyPrefix = "totp:login:"
|
||||
totpAttemptsKeyPrefix = "totp:attempts:"
|
||||
totpAttemptsTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
// TotpCache implements service.TotpCache using Redis
|
||||
type TotpCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewTotpCache creates a new TOTP cache
|
||||
func NewTotpCache(rdb *redis.Client) service.TotpCache {
|
||||
return &TotpCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// GetSetupSession retrieves a TOTP setup session
|
||||
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
data, err := c.rdb.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get setup session: %w", err)
|
||||
}
|
||||
|
||||
var session service.TotpSetupSession
|
||||
if err := json.Unmarshal(data, &session); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal setup session: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// SetSetupSession stores a TOTP setup session
|
||||
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal setup session: %w", err)
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
|
||||
return fmt.Errorf("set setup session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSetupSession deletes a TOTP setup session
|
||||
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// GetLoginSession retrieves a TOTP login session
|
||||
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
data, err := c.rdb.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get login session: %w", err)
|
||||
}
|
||||
|
||||
var session service.TotpLoginSession
|
||||
if err := json.Unmarshal(data, &session); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal login session: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// SetLoginSession stores a TOTP login session
|
||||
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal login session: %w", err)
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
|
||||
return fmt.Errorf("set login session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteLoginSession deletes a TOTP login session
|
||||
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// IncrementVerifyAttempts increments the verify attempt counter
|
||||
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
|
||||
// Use pipeline for atomic increment and set TTL
|
||||
pipe := c.rdb.Pipeline()
|
||||
incrCmd := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, totpAttemptsTTL)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, fmt.Errorf("increment verify attempts: %w", err)
|
||||
}
|
||||
|
||||
count, err := incrCmd.Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get increment result: %w", err)
|
||||
}
|
||||
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
// GetVerifyAttempts gets the current verify attempt count
|
||||
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("get verify attempts: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ClearVerifyAttempts clears the verify attempt counter
|
||||
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
|
||||
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
update := client.User.UpdateOneID(userID)
|
||||
if encryptedSecret == nil {
|
||||
update = update.ClearTotpSecretEncrypted()
|
||||
} else {
|
||||
update = update.SetTotpSecretEncrypted(*encryptedSecret)
|
||||
}
|
||||
_, err := update.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableTotp 启用用户的 TOTP 双因素认证
|
||||
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.User.UpdateOneID(userID).
|
||||
SetTotpEnabled(true).
|
||||
SetTotpEnabledAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableTotp 禁用用户的 TOTP 双因素认证
|
||||
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.User.UpdateOneID(userID).
|
||||
SetTotpEnabled(false).
|
||||
ClearTotpEnabledAt().
|
||||
ClearTotpSecretEncrypted().
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
q := client.UserSubscription.Query()
|
||||
if userID != nil {
|
||||
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
if groupID != nil {
|
||||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||||
}
|
||||
if status != "" {
|
||||
|
||||
// Status filtering with real-time expiration check
|
||||
now := time.Now()
|
||||
switch status {
|
||||
case service.SubscriptionStatusActive:
|
||||
// Active: status is active AND not yet expired
|
||||
q = q.Where(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtGT(now),
|
||||
)
|
||||
case service.SubscriptionStatusExpired:
|
||||
// Expired: status is expired OR (status is active but already expired)
|
||||
q = q.Where(
|
||||
usersubscription.Or(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
|
||||
usersubscription.And(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtLTE(now),
|
||||
),
|
||||
),
|
||||
)
|
||||
case "":
|
||||
// No filter
|
||||
default:
|
||||
// Other status (e.g., revoked)
|
||||
q = q.Where(usersubscription.StatusEQ(status))
|
||||
}
|
||||
|
||||
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
q = q.WithUser().WithGroup().WithAssignedByUser()
|
||||
|
||||
// Determine sort field
|
||||
var field string
|
||||
switch sortBy {
|
||||
case "expires_at":
|
||||
field = usersubscription.FieldExpiresAt
|
||||
case "status":
|
||||
field = usersubscription.FieldStatus
|
||||
default:
|
||||
field = usersubscription.FieldCreatedAt
|
||||
}
|
||||
|
||||
// Determine sort order (default: desc)
|
||||
if sortOrder == "asc" && sortBy != "" {
|
||||
q = q.Order(dbent.Asc(field))
|
||||
} else {
|
||||
q = q.Order(dbent.Desc(field))
|
||||
}
|
||||
|
||||
subs, err := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
WithAssignedByUser().
|
||||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
All(ctx)
|
||||
|
||||
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
group := s.mustCreateGroup("g-list")
|
||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||
|
||||
@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet(
|
||||
NewSchedulerCache,
|
||||
NewSchedulerOutboxRepository,
|
||||
NewProxyLatencyCache,
|
||||
NewTotpCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
|
||||
@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
|
||||
"id": 1,
|
||||
"email": "alice@example.com",
|
||||
"username": "alice",
|
||||
"notes": "hello",
|
||||
"role": "user",
|
||||
"balance": 12.5,
|
||||
"concurrency": 5,
|
||||
@@ -131,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "GET /api/v1/groups/available",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count)。
|
||||
deps.groupRepo.SetActive([]service.Group{
|
||||
{
|
||||
ID: 10,
|
||||
Name: "Group One",
|
||||
Description: "desc",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.5,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-3-*": []int64{101, 102},
|
||||
},
|
||||
AccountCount: 2,
|
||||
CreatedAt: deps.now,
|
||||
UpdatedAt: deps.now,
|
||||
},
|
||||
})
|
||||
deps.userSubRepo.SetActiveByUserID(1, nil)
|
||||
},
|
||||
method: http.MethodGet,
|
||||
path: "/api/v1/groups/available",
|
||||
wantStatus: http.StatusOK,
|
||||
wantJSON: `{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 10,
|
||||
"name": "Group One",
|
||||
"description": "desc",
|
||||
"platform": "anthropic",
|
||||
"rate_multiplier": 1.5,
|
||||
"is_exclusive": false,
|
||||
"status": "active",
|
||||
"subscription_type": "standard",
|
||||
"daily_limit_usd": null,
|
||||
"weekly_limit_usd": null,
|
||||
"monthly_limit_usd": null,
|
||||
"image_price_1k": null,
|
||||
"image_price_2k": null,
|
||||
"image_price_4k": null,
|
||||
"claude_code_only": false,
|
||||
"fallback_group_id": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "GET /api/v1/subscriptions",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
|
||||
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
|
||||
{
|
||||
ID: 501,
|
||||
UserID: 1,
|
||||
GroupID: 10,
|
||||
StartsAt: deps.now,
|
||||
ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
|
||||
Status: service.SubscriptionStatusActive,
|
||||
DailyUsageUSD: 1.23,
|
||||
WeeklyUsageUSD: 2.34,
|
||||
MonthlyUsageUSD: 3.45,
|
||||
AssignedBy: ptr(int64(999)),
|
||||
AssignedAt: deps.now,
|
||||
Notes: "admin-note",
|
||||
CreatedAt: deps.now,
|
||||
UpdatedAt: deps.now,
|
||||
},
|
||||
})
|
||||
},
|
||||
method: http.MethodGet,
|
||||
path: "/api/v1/subscriptions",
|
||||
wantStatus: http.StatusOK,
|
||||
wantJSON: `{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 501,
|
||||
"user_id": 1,
|
||||
"group_id": 10,
|
||||
"starts_at": "2025-01-02T03:04:05Z",
|
||||
"expires_at": "2099-01-02T03:04:05Z",
|
||||
"status": "active",
|
||||
"daily_window_start": null,
|
||||
"weekly_window_start": null,
|
||||
"monthly_window_start": null,
|
||||
"daily_usage_usd": 1.23,
|
||||
"weekly_usage_usd": 2.34,
|
||||
"monthly_usage_usd": 3.45,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "GET /api/v1/redeem/history",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
// 普通用户兑换历史不应包含 notes 等内部字段。
|
||||
deps.redeemRepo.SetByUser(1, []service.RedeemCode{
|
||||
{
|
||||
ID: 900,
|
||||
Code: "CODE-123",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 1.25,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: ptr(int64(1)),
|
||||
UsedAt: ptr(deps.now),
|
||||
Notes: "internal-note",
|
||||
CreatedAt: deps.now,
|
||||
},
|
||||
})
|
||||
},
|
||||
method: http.MethodGet,
|
||||
path: "/api/v1/redeem/history",
|
||||
wantStatus: http.StatusOK,
|
||||
wantJSON: `{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 900,
|
||||
"code": "CODE-123",
|
||||
"type": "balance",
|
||||
"value": 1.25,
|
||||
"status": "used",
|
||||
"used_by": 1,
|
||||
"used_at": "2025-01-02T03:04:05Z",
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"group_id": null,
|
||||
"validity_days": 0
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "GET /api/v1/usage/stats",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
@@ -190,24 +336,25 @@ func TestAPIContracts(t *testing.T) {
|
||||
t.Helper()
|
||||
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
|
||||
{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
RequestID: "req_123",
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 1,
|
||||
CacheReadTokens: 2,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
RateMultiplier: 1,
|
||||
BillingType: service.BillingTypeBalance,
|
||||
Stream: true,
|
||||
DurationMs: ptr(100),
|
||||
FirstTokenMs: ptr(50),
|
||||
CreatedAt: deps.now,
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
AccountRateMultiplier: ptr(0.5),
|
||||
RequestID: "req_123",
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 1,
|
||||
CacheReadTokens: 2,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
RateMultiplier: 1,
|
||||
BillingType: service.BillingTypeBalance,
|
||||
Stream: true,
|
||||
DurationMs: ptr(100),
|
||||
FirstTokenMs: ptr(50),
|
||||
CreatedAt: deps.now,
|
||||
},
|
||||
})
|
||||
},
|
||||
@@ -238,10 +385,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"output_cost": 0,
|
||||
"cache_creation_cost": 0,
|
||||
"cache_read_cost": 0,
|
||||
"total_cost": 0.5,
|
||||
"total_cost": 0.5,
|
||||
"actual_cost": 0.5,
|
||||
"rate_multiplier": 1,
|
||||
"account_rate_multiplier": null,
|
||||
"billing_type": 0,
|
||||
"stream": true,
|
||||
"duration_ms": 100,
|
||||
@@ -266,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
deps.settingRepo.SetAll(map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
|
||||
service.SettingKeySMTPHost: "smtp.example.com",
|
||||
service.SettingKeySMTPPort: "587",
|
||||
@@ -304,6 +451,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"data": {
|
||||
"registration_enabled": true,
|
||||
"email_verify_enabled": false,
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "smtp.example.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "user",
|
||||
@@ -337,7 +488,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"fallback_model_openai": "gpt-4o",
|
||||
"enable_identity_patch": true,
|
||||
"identity_patch_prompt": "",
|
||||
"home_content": ""
|
||||
"home_content": "",
|
||||
"hide_ccs_import_button": false,
|
||||
"purchase_subscription_enabled": false,
|
||||
"purchase_subscription_url": ""
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -385,8 +539,11 @@ type contractDeps struct {
|
||||
now time.Time
|
||||
router http.Handler
|
||||
apiKeyRepo *stubApiKeyRepo
|
||||
groupRepo *stubGroupRepo
|
||||
userSubRepo *stubUserSubscriptionRepo
|
||||
usageRepo *stubUsageLogRepo
|
||||
settingRepo *stubSettingRepo
|
||||
redeemRepo *stubRedeemCodeRepo
|
||||
}
|
||||
|
||||
func newContractDeps(t *testing.T) *contractDeps {
|
||||
@@ -414,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
|
||||
apiKeyRepo := newStubApiKeyRepo(now)
|
||||
apiKeyCache := stubApiKeyCache{}
|
||||
groupRepo := stubGroupRepo{}
|
||||
userSubRepo := stubUserSubscriptionRepo{}
|
||||
groupRepo := &stubGroupRepo{}
|
||||
userSubRepo := &stubUserSubscriptionRepo{}
|
||||
accountRepo := stubAccountRepo{}
|
||||
proxyRepo := stubProxyRepo{}
|
||||
redeemRepo := stubRedeemCodeRepo{}
|
||||
redeemRepo := &stubRedeemCodeRepo{}
|
||||
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
@@ -433,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||
|
||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
|
||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
jwtAuth := func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
@@ -472,12 +635,21 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
v1Keys.Use(jwtAuth)
|
||||
v1Keys.GET("/keys", apiKeyHandler.List)
|
||||
v1Keys.POST("/keys", apiKeyHandler.Create)
|
||||
v1Keys.GET("/groups/available", apiKeyHandler.GetAvailableGroups)
|
||||
|
||||
v1Usage := v1.Group("")
|
||||
v1Usage.Use(jwtAuth)
|
||||
v1Usage.GET("/usage", usageHandler.List)
|
||||
v1Usage.GET("/usage/stats", usageHandler.Stats)
|
||||
|
||||
v1Subs := v1.Group("")
|
||||
v1Subs.Use(jwtAuth)
|
||||
v1Subs.GET("/subscriptions", subscriptionHandler.List)
|
||||
|
||||
v1Redeem := v1.Group("")
|
||||
v1Redeem.Use(jwtAuth)
|
||||
v1Redeem.GET("/redeem/history", redeemHandler.GetHistory)
|
||||
|
||||
v1Admin := v1.Group("/admin")
|
||||
v1Admin.Use(adminAuth)
|
||||
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
|
||||
@@ -487,8 +659,11 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
now: now,
|
||||
router: r,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
usageRepo: usageRepo,
|
||||
settingRepo: settingRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -584,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubApiKeyCache struct{}
|
||||
|
||||
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
@@ -626,7 +813,13 @@ func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handl
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubGroupRepo struct{}
|
||||
type stubGroupRepo struct {
|
||||
active []service.Group
|
||||
}
|
||||
|
||||
func (r *stubGroupRepo) SetActive(groups []service.Group) {
|
||||
r.active = append([]service.Group(nil), groups...)
|
||||
}
|
||||
|
||||
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
||||
return errors.New("not implemented")
|
||||
@@ -660,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
return append([]service.Group(nil), r.active...), nil
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
out := make([]service.Group, 0, len(r.active))
|
||||
for i := range r.active {
|
||||
g := r.active[i]
|
||||
if g.Platform == platform {
|
||||
out = append(out, g)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
@@ -883,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubRedeemCodeRepo struct{}
|
||||
type stubRedeemCodeRepo struct {
|
||||
byUser map[int64][]service.RedeemCode
|
||||
}
|
||||
|
||||
func (r *stubRedeemCodeRepo) SetByUser(userID int64, codes []service.RedeemCode) {
|
||||
if r.byUser == nil {
|
||||
r.byUser = make(map[int64][]service.RedeemCode)
|
||||
}
|
||||
r.byUser[userID] = append([]service.RedeemCode(nil), codes...)
|
||||
}
|
||||
|
||||
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||
return errors.New("not implemented")
|
||||
@@ -921,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||
if r.byUser == nil {
|
||||
return nil, nil
|
||||
}
|
||||
codes := r.byUser[userID]
|
||||
if limit > 0 && len(codes) > limit {
|
||||
codes = codes[:limit]
|
||||
}
|
||||
return append([]service.RedeemCode(nil), codes...), nil
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct{}
|
||||
type stubUserSubscriptionRepo struct {
|
||||
byUser map[int64][]service.UserSubscription
|
||||
activeByUser map[int64][]service.UserSubscription
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) SetByUserID(userID int64, subs []service.UserSubscription) {
|
||||
if r.byUser == nil {
|
||||
r.byUser = make(map[int64][]service.UserSubscription)
|
||||
}
|
||||
r.byUser[userID] = append([]service.UserSubscription(nil), subs...)
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) SetActiveByUserID(userID int64, subs []service.UserSubscription) {
|
||||
if r.activeByUser == nil {
|
||||
r.activeByUser = make(map[int64][]service.UserSubscription)
|
||||
}
|
||||
r.activeByUser[userID] = append([]service.UserSubscription(nil), subs...)
|
||||
}
|
||||
|
||||
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
@@ -945,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
|
||||
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
if r.byUser == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return append([]service.UserSubscription(nil), r.byUser[userID]...), nil
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
if r.activeByUser == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return append([]service.UserSubscription(nil), r.activeByUser[userID]...), nil
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
|
||||
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ValidatePromoCode)
|
||||
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
|
||||
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ForgotPassword)
|
||||
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
|
||||
// TOTP 双因素认证
|
||||
totp := user.Group("/totp")
|
||||
{
|
||||
totp.GET("/status", h.Totp.GetStatus)
|
||||
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
|
||||
totp.POST("/send-code", h.Totp.SendVerifyCode)
|
||||
totp.POST("/setup", h.Totp.InitiateSetup)
|
||||
totp.POST("/enable", h.Totp.Enable)
|
||||
totp.POST("/disable", h.Totp.Disable)
|
||||
}
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
|
||||
@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCredentialAsInt64 解析凭证中的 int64 字段
|
||||
// 用于读取 _token_version 等内部字段
|
||||
func (a *Account) GetCredentialAsInt64(key string) int64 {
|
||||
if a == nil || a.Credentials == nil {
|
||||
return 0
|
||||
}
|
||||
val, ok := a.Credentials[key]
|
||||
if !ok || val == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case int64:
|
||||
return v
|
||||
case float64:
|
||||
return int64(v)
|
||||
case int:
|
||||
return int64(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return i
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *Account) IsTempUnschedulableEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
|
||||
@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
|
||||
} `json:"seven_day_sonnet"`
|
||||
}
|
||||
|
||||
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
|
||||
type ClaudeUsageFetchOptions struct {
|
||||
AccessToken string // OAuth access token
|
||||
ProxyURL string // 代理 URL(可选)
|
||||
AccountID int64 // 账号 ID(用于 TLS 指纹选择)
|
||||
EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装
|
||||
Fingerprint *Fingerprint // 缓存的指纹信息(User-Agent 等)
|
||||
}
|
||||
|
||||
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||
type ClaudeUsageFetcher interface {
|
||||
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
|
||||
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
|
||||
}
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
@@ -170,6 +181,7 @@ type AccountUsageService struct {
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher
|
||||
cache *UsageCache
|
||||
identityCache IdentityCache
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
@@ -180,6 +192,7 @@ func NewAccountUsageService(
|
||||
geminiQuotaService *GeminiQuotaService,
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
||||
cache *UsageCache,
|
||||
identityCache IdentityCache,
|
||||
) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -188,6 +201,7 @@ func NewAccountUsageService(
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
antigravityQuotaFetcher: antigravityQuotaFetcher,
|
||||
cache: cache,
|
||||
identityCache: identityCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -424,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
|
||||
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
|
||||
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
|
||||
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
@@ -435,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
// 构建完整的选项
|
||||
opts := &ClaudeUsageFetchOptions{
|
||||
AccessToken: accessToken,
|
||||
ProxyURL: proxyURL,
|
||||
AccountID: account.ID,
|
||||
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
|
||||
}
|
||||
|
||||
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
|
||||
if s.identityCache != nil {
|
||||
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
|
||||
opts.Fingerprint = fp
|
||||
}
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
|
||||
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
|
||||
type groupRepoStub struct {
|
||||
affectedUserIDs []int64
|
||||
deleteErr error
|
||||
|
||||
@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 清理 Schema
|
||||
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
|
||||
injectedBody = cleanedBody
|
||||
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
|
||||
} else {
|
||||
log.Printf("[Antigravity] Failed to clean schema: %v", err)
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
||||
if err != nil {
|
||||
@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
usage = u
|
||||
}
|
||||
// Check for MALFORMED_FUNCTION_CALL
|
||||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if cand, ok := candidates[0].(map[string]any); ok {
|
||||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
|
||||
if content, ok := cand["content"]; ok {
|
||||
if b, err := json.Marshal(content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
usage = u
|
||||
}
|
||||
|
||||
// Check for MALFORMED_FUNCTION_CALL
|
||||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if cand, ok := candidates[0].(map[string]any); ok {
|
||||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
|
||||
if content, ok := cand["content"]; ok {
|
||||
if b, err := json.Marshal(content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保留最后一个有 parts 的响应
|
||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||
lastWithParts = parsed
|
||||
@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
|
||||
return result, existingParts, setParts
|
||||
}
|
||||
|
||||
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
|
||||
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
|
||||
// 保持原始顺序,只合并连续的普通 text parts
|
||||
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
|
||||
if len(collectedParts) == 0 {
|
||||
return response
|
||||
}
|
||||
|
||||
result, _, setParts := getOrCreateGeminiParts(response)
|
||||
|
||||
// 合并策略:
|
||||
// 1. 保持原始顺序
|
||||
// 2. 连续的普通 text parts 合并为一个
|
||||
// 3. thinking、functionCall、inlineData 等保持原样
|
||||
var mergedParts []any
|
||||
var textBuffer strings.Builder
|
||||
|
||||
flushTextBuffer := func() {
|
||||
if textBuffer.Len() > 0 {
|
||||
mergedParts = append(mergedParts, map[string]any{
|
||||
"text": textBuffer.String(),
|
||||
})
|
||||
textBuffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range collectedParts {
|
||||
// 检查是否是普通 text part
|
||||
if text, ok := part["text"].(string); ok {
|
||||
// 检查是否有 thought 标记
|
||||
if thought, _ := part["thought"].(bool); thought {
|
||||
// thinking part,先刷新 text buffer,然后保留原样
|
||||
flushTextBuffer()
|
||||
mergedParts = append(mergedParts, part)
|
||||
} else {
|
||||
// 普通 text,累积到 buffer
|
||||
_, _ = textBuffer.WriteString(text)
|
||||
}
|
||||
} else {
|
||||
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
|
||||
flushTextBuffer()
|
||||
mergedParts = append(mergedParts, part)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新剩余的 text
|
||||
flushTextBuffer()
|
||||
|
||||
setParts(mergedParts)
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
|
||||
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
|
||||
if len(imageParts) == 0 {
|
||||
@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
var firstTokenMs *int
|
||||
var last map[string]any
|
||||
var lastWithParts map[string]any
|
||||
var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
|
||||
last = parsed
|
||||
|
||||
// 保留最后一个有 parts 的响应
|
||||
// 保留最后一个有 parts 的响应,并收集所有 parts
|
||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||
lastWithParts = parsed
|
||||
|
||||
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
|
||||
collectedParts = append(collectedParts, parts...)
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -2252,6 +2343,11 @@ returnResponse:
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
||||
}
|
||||
|
||||
// 将收集的所有 parts 合并到最终响应中
|
||||
if len(collectedParts) > 0 {
|
||||
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
|
||||
}
|
||||
|
||||
// 序列化为 JSON(Gemini 格式)
|
||||
geminiBody, err := json.Marshal(finalResponse)
|
||||
if err != nil {
|
||||
@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
|
||||
modelLower == "gemini-2.5-flash-image-preview" ||
|
||||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
|
||||
}
|
||||
|
||||
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
|
||||
func cleanGeminiRequest(body []byte) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modified := false
|
||||
|
||||
// 1. 清理 Tools
|
||||
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
|
||||
for _, t := range tools {
|
||||
toolMap, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// function_declarations (snake_case) or functionDeclarations (camelCase)
|
||||
var funcs []any
|
||||
if f, ok := toolMap["functionDeclarations"].([]any); ok {
|
||||
funcs = f
|
||||
} else if f, ok := toolMap["function_declarations"].([]any); ok {
|
||||
funcs = f
|
||||
}
|
||||
|
||||
if len(funcs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, f := range funcs {
|
||||
funcMap, ok := f.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if params, ok := funcMap["parameters"].(map[string]any); ok {
|
||||
antigravity.DeepCleanUndefined(params)
|
||||
cleaned := antigravity.CleanJSONSchema(params)
|
||||
funcMap["parameters"] = cleaned
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
result.Email = userInfo.Email
|
||||
}
|
||||
|
||||
// 获取 project_id(部分账户类型可能没有)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
|
||||
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
result.ProjectID = loadResp.CloudAICompanionProject
|
||||
// 获取 project_id(部分账户类型可能没有),失败时重试
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
|
||||
if loadErr != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
|
||||
result.ProjectIDMissing = true
|
||||
} else {
|
||||
result.ProjectID = projectID
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
|
||||
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
|
||||
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
|
||||
|
||||
if loadErr != nil {
|
||||
// LoadCodeAssist 失败,保留原有 project_id
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
|
||||
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
|
||||
if existingProjectID == "" {
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
}
|
||||
} else {
|
||||
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// loadProjectIDWithRetry 带重试机制获取 project_id
|
||||
// 返回 project_id 和错误,失败时会重试指定次数
|
||||
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// 指数退避:1s, 2s, 4s
|
||||
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
|
||||
if backoff > 8*time.Second {
|
||||
backoff = 8 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else if loadResp == nil {
|
||||
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
|
||||
} else {
|
||||
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
|
||||
@@ -94,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
|
||||
var handleErrorCalled bool
|
||||
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
prefix: "[test]",
|
||||
ctx: context.Background(),
|
||||
account: account,
|
||||
proxyURL: "",
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
prefix: "[test]",
|
||||
ctx: context.Background(),
|
||||
account: account,
|
||||
proxyURL: "",
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
httpUpstream: upstream,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
handleErrorCalled = true
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > antigravityTokenCacheSkew:
|
||||
ttl = until - antigravityTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > antigravityTokenCacheSkew:
|
||||
ttl = until - antigravityTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
|
||||
@@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
||||
}
|
||||
|
||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,只记录警告,不返回错误
|
||||
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
|
||||
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
|
||||
if tokenInfo.ProjectID != "" {
|
||||
// 有旧的 project_id,本次获取失败,保留旧值
|
||||
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
|
||||
} else {
|
||||
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
|
||||
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
|
||||
@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
// 应用优惠码(如果提供且功能已启用)
|
||||
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
// 生成新token
|
||||
return s.GenerateToken(user)
|
||||
}
|
||||
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证且 SMTP 配置正确
|
||||
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return false
|
||||
}
|
||||
// Must have email verification enabled and SMTP configured
|
||||
if !s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
return s.settingService.IsPasswordResetEnabled(ctx)
|
||||
}
|
||||
|
||||
// preparePasswordReset validates the password reset request and returns necessary data
|
||||
// Returns (siteName, resetURL, shouldProceed)
|
||||
// shouldProceed is false when we should silently return success (to prevent enumeration)
|
||||
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
|
||||
// Check if user exists (but don't reveal this to the caller)
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// Security: Log but don't reveal that user doesn't exist
|
||||
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
|
||||
return "", "", false
|
||||
}
|
||||
log.Printf("[Auth] Database error checking email for password reset: %v", err)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive() {
|
||||
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Get site name
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
|
||||
// Build reset URL base
|
||||
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
|
||||
|
||||
return siteName, resetURL, true
|
||||
}
|
||||
|
||||
// RequestPasswordReset 请求密码重置(同步发送)
|
||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
if s.emailService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
|
||||
if !shouldProceed {
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset email sent to: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
|
||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
if s.emailQueueService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
|
||||
if !shouldProceed {
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
|
||||
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset email enqueued for: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// Security: Increments TokenVersion to invalidate all existing JWT tokens
|
||||
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
|
||||
// Check if password reset is enabled
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
|
||||
if s.emailService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Verify and consume the reset token (one-time use)
|
||||
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return ErrInvalidResetToken // Token was valid but user was deleted
|
||||
}
|
||||
log.Printf("[Auth] Database error getting user for password reset: %v", err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive() {
|
||||
return ErrUserNotActive
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := s.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password and increment TokenVersion
|
||||
user.PasswordHash = hashedPassword
|
||||
user.TokenVersion++ // Invalidate all existing tokens
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset successful for user: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
|
||||
@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > claudeTokenCacheSkew:
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > claudeTokenCacheSkew:
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -69,8 +69,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
@@ -86,6 +88,9 @@ const (
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// TOTP 双因素认证设置
|
||||
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
|
||||
|
||||
// LinuxDo Connect OAuth 登录设置
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
@@ -93,13 +98,16 @@ const (
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
|
||||
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口
|
||||
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
|
||||
@@ -8,11 +8,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Task type constants
|
||||
const (
|
||||
TaskTypeVerifyCode = "verify_code"
|
||||
TaskTypePasswordReset = "password_reset"
|
||||
)
|
||||
|
||||
// EmailTask 邮件发送任务
|
||||
type EmailTask struct {
|
||||
Email string
|
||||
SiteName string
|
||||
TaskType string // "verify_code"
|
||||
TaskType string // "verify_code" or "password_reset"
|
||||
ResetURL string // Only used for password_reset task type
|
||||
}
|
||||
|
||||
// EmailQueueService 异步邮件队列服务
|
||||
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
||||
defer cancel()
|
||||
|
||||
switch task.TaskType {
|
||||
case "verify_code":
|
||||
case TaskTypeVerifyCode:
|
||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||
}
|
||||
case TaskTypePasswordReset:
|
||||
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
||||
}
|
||||
default:
|
||||
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
||||
}
|
||||
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: "verify_code",
|
||||
TaskType: TaskTypeVerifyCode,
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// EnqueuePasswordReset 将密码重置邮件任务加入队列
|
||||
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: TaskTypePasswordReset,
|
||||
ResetURL: resetURL,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.taskChan <- task:
|
||||
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("email queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止队列服务
|
||||
func (s *EmailQueueService) Stop() {
|
||||
close(s.stopChan)
|
||||
|
||||
@@ -3,11 +3,14 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +22,9 @@ var (
|
||||
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
|
||||
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
|
||||
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
|
||||
|
||||
// Password reset errors
|
||||
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
|
||||
)
|
||||
|
||||
// EmailCache defines cache operations for email service
|
||||
@@ -26,6 +32,16 @@ type EmailCache interface {
|
||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||
DeleteVerificationCode(ctx context.Context, email string) error
|
||||
|
||||
// Password reset token methods
|
||||
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
|
||||
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
|
||||
DeletePasswordResetToken(ctx context.Context, email string) error
|
||||
|
||||
// Password reset email cooldown methods
|
||||
// Returns true if in cooldown period (email was sent recently)
|
||||
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
|
||||
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// VerificationCodeData represents verification code data
|
||||
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// PasswordResetTokenData represents password reset token data
|
||||
type PasswordResetTokenData struct {
|
||||
Token string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
|
||||
// Password reset token settings
|
||||
passwordResetTokenTTL = 30 * time.Minute
|
||||
|
||||
// Password reset email cooldown (prevent email bombing)
|
||||
passwordResetEmailCooldown = 30 * time.Second
|
||||
)
|
||||
|
||||
// SMTPConfig SMTP配置
|
||||
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||
data.Attempts++
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
||||
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
|
||||
return client.Quit()
|
||||
}
|
||||
|
||||
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
|
||||
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// SendPasswordResetEmail sends a password reset email with a reset link
|
||||
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
|
||||
var token string
|
||||
var needSaveToken bool
|
||||
|
||||
// Check if token already exists
|
||||
existing, err := s.cache.GetPasswordResetToken(ctx, email)
|
||||
if err == nil && existing != nil {
|
||||
// Token exists, reuse it (allows resending email without generating new token)
|
||||
token = existing.Token
|
||||
needSaveToken = false
|
||||
} else {
|
||||
// Generate new token
|
||||
token, err = s.GeneratePasswordResetToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
needSaveToken = true
|
||||
}
|
||||
|
||||
// Save token to Redis (only if new token generated)
|
||||
if needSaveToken {
|
||||
data := &PasswordResetTokenData{
|
||||
Token: token,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
|
||||
return fmt.Errorf("save reset token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build full reset URL with URL-encoded token and email
|
||||
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
|
||||
|
||||
// Build email content
|
||||
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
|
||||
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
|
||||
|
||||
// Send email
|
||||
if err := s.SendEmail(ctx, email, subject, body); err != nil {
|
||||
return fmt.Errorf("send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
|
||||
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
|
||||
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
|
||||
// Check email cooldown to prevent email bombing
|
||||
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
||||
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
|
||||
return nil // Silent success to prevent revealing cooldown to attackers
|
||||
}
|
||||
|
||||
// Send email using core method
|
||||
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set cooldown marker (Redis TTL handles expiration)
|
||||
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
|
||||
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyPasswordResetToken verifies the password reset token without consuming it
|
||||
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
|
||||
data, err := s.cache.GetPasswordResetToken(ctx, email)
|
||||
if err != nil || data == nil {
|
||||
return ErrInvalidResetToken
|
||||
}
|
||||
|
||||
// Use constant-time comparison to prevent timing attacks
|
||||
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
|
||||
return ErrInvalidResetToken
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
|
||||
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
|
||||
// Verify first
|
||||
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete after verification (one-time use)
|
||||
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
|
||||
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPasswordResetEmailBody builds the HTML content for password reset email
|
||||
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
|
||||
return fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
|
||||
.header h1 { margin: 0; font-size: 24px; }
|
||||
.content { padding: 40px 30px; text-align: center; }
|
||||
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
|
||||
.button:hover { opacity: 0.9; }
|
||||
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
|
||||
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
|
||||
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||
.warning { color: #e74c3c; font-weight: 500; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>%s</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p style="font-size: 18px; color: #333;">密码重置请求</p>
|
||||
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
|
||||
<a href="%s" class="button">重置密码</a>
|
||||
<div class="info">
|
||||
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
|
||||
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
|
||||
</div>
|
||||
<div class="link-fallback">
|
||||
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
|
||||
<p>%s</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>这是一封自动发送的邮件,请勿回复。</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, siteName, resetURL, resetURL)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -99,11 +99,24 @@ var allowedHeaders = map[string]bool{
|
||||
"content-type": true,
|
||||
}
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
// GatewayCache 定义网关服务的缓存操作接口。
|
||||
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
||||
//
|
||||
// GatewayCache defines cache operations for gateway service.
|
||||
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
||||
type GatewayCache interface {
|
||||
// GetSessionAccountID 获取粘性会话绑定的账号 ID
|
||||
// Get the account ID bound to a sticky session
|
||||
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
|
||||
// SetSessionAccountID 设置粘性会话与账号的绑定关系
|
||||
// Set the binding between sticky session and account
|
||||
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
// RefreshSessionTTL 刷新粘性会话的过期时间
|
||||
// Refresh the expiration time of a sticky session
|
||||
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
|
||||
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
||||
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
||||
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||
}
|
||||
|
||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||
@@ -114,6 +127,28 @@ func derefGroupID(groupID *int64) int64 {
|
||||
return *groupID
|
||||
}
|
||||
|
||||
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
|
||||
// 这确保后续请求不会继续使用不可用的账号。
|
||||
//
|
||||
// shouldClearStickySession checks if an account is in an unschedulable state
|
||||
// and the sticky session binding should be cleared.
|
||||
// Returns true when account status is error/disabled, schedulable is false,
|
||||
// or within temporary unschedulable period.
|
||||
// This ensures subsequent requests won't continue using unavailable accounts.
|
||||
func shouldClearStickySession(account *Account) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
|
||||
return true
|
||||
}
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type AccountWaitPlan struct {
|
||||
AccountID int64
|
||||
MaxConcurrency int
|
||||
@@ -270,6 +305,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
|
||||
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
|
||||
}
|
||||
|
||||
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
|
||||
// Returns 0 if no binding exists or on error.
|
||||
func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
|
||||
if sessionHash == "" || s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
@@ -658,6 +706,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
}
|
||||
} else {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -764,41 +814,52 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, ok := accountByID[accountID]
|
||||
if ok && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if ok {
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
// Check if the account needs sticky session cleanup
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
// Session count limit check
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
// Session count limit check (wait plan also requires session quota)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
// Session limit full, continue to Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1418,14 +1479,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1515,11 +1582,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1619,15 +1692,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1718,12 +1797,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3287,17 +3372,19 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
} `json:"usage"`
|
||||
}
|
||||
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
|
||||
// output_tokens 总是从 message_delta 获取
|
||||
usage.OutputTokens = msgDelta.Usage.OutputTokens
|
||||
|
||||
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
|
||||
if usage.InputTokens == 0 {
|
||||
// message_delta 仅覆盖存在且非0的字段
|
||||
// 避免覆盖 message_start 中已有的值(如 input_tokens)
|
||||
// Claude API 的 message_delta 通常只包含 output_tokens
|
||||
if msgDelta.Usage.InputTokens > 0 {
|
||||
usage.InputTokens = msgDelta.Usage.InputTokens
|
||||
}
|
||||
if usage.CacheCreationInputTokens == 0 {
|
||||
if msgDelta.Usage.OutputTokens > 0 {
|
||||
usage.OutputTokens = msgDelta.Usage.OutputTokens
|
||||
}
|
||||
if msgDelta.Usage.CacheCreationInputTokens > 0 {
|
||||
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
|
||||
}
|
||||
if usage.CacheReadInputTokens == 0 {
|
||||
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,70 +82,23 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
var err error
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
// 无分组时只使用原生 gemini 平台
|
||||
platform = PlatformGemini
|
||||
// 1. 确定目标平台和调度模式
|
||||
// Determine target platform and scheduling mode
|
||||
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
valid := false
|
||||
if account.Platform == platform {
|
||||
valid = true
|
||||
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
valid = true
|
||||
}
|
||||
if valid {
|
||||
usable := true
|
||||
if s.rateLimitService != nil && requestedModel != "" {
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
}
|
||||
if !ok {
|
||||
usable = false
|
||||
}
|
||||
}
|
||||
if usable {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 2. 尝试粘性会话命中
|
||||
// Try sticky session hit
|
||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
// Query schedulable accounts (force platform mode: try group first, fallback to all)
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -158,56 +111,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
|
||||
// 非混合调度模式(antigravity 分组):不需要过滤
|
||||
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if s.rateLimitService != nil && requestedModel != "" {
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
|
||||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 4. 按优先级 + LRU 选择最佳账号
|
||||
// Select best account by priority + LRU
|
||||
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -216,6 +122,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
return nil, errors.New("no available Gemini accounts")
|
||||
}
|
||||
|
||||
// 5. 设置粘性会话绑定
|
||||
// Set sticky session binding
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
|
||||
}
|
||||
@@ -223,6 +131,229 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
|
||||
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
|
||||
//
|
||||
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
|
||||
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
|
||||
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
return forcePlatform, false, true, nil
|
||||
}
|
||||
|
||||
if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return "", false, false, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
return group.Platform, group.Platform == PlatformGemini, false, nil
|
||||
}
|
||||
|
||||
// 无分组时只使用原生 gemini 平台
|
||||
return PlatformGemini, true, false, nil
|
||||
}
|
||||
|
||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
||||
//
|
||||
// tryStickySessionHit attempts to get account from sticky session.
|
||||
// Returns account if hit and usable; clears session and returns nil if account unavailable.
|
||||
func (s *GeminiMessagesCompatService) tryStickySessionHit(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
sessionHash, cacheKey, requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
platform string,
|
||||
useMixedScheduling bool,
|
||||
) *Account {
|
||||
if sessionHash == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err != nil || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, excluded := excludedIDs[accountID]; excluded {
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查账号是否需要清理粘性会话
|
||||
// Check if sticky session should be cleared
|
||||
if shouldClearStickySession(account) {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 验证账号是否可用于当前请求
|
||||
// Verify account is usable for current request
|
||||
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新会话 TTL 并返回账号
|
||||
// Refresh session TTL and return account
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||
return account
|
||||
}
|
||||
|
||||
// isAccountUsableForRequest 检查账号是否可用于当前请求。
|
||||
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
|
||||
//
|
||||
// isAccountUsableForRequest checks if account is usable for current request.
|
||||
// Validates: model scheduling, model support, platform matching, rate limit precheck.
|
||||
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
requestedModel, platform string,
|
||||
useMixedScheduling bool,
|
||||
) bool {
|
||||
// 检查模型调度能力
|
||||
// Check model scheduling capability
|
||||
if !account.IsSchedulableForModel(requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查平台匹配
|
||||
// Check platform matching
|
||||
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 速率限制预检
|
||||
// Rate limit precheck
|
||||
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isAccountValidForPlatform 检查账号是否匹配目标平台。
|
||||
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
|
||||
//
|
||||
// isAccountValidForPlatform checks if account matches target platform.
|
||||
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
|
||||
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
|
||||
if account.Platform == platform {
|
||||
return true
|
||||
}
|
||||
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// passesRateLimitPreCheck 执行速率限制预检。
|
||||
// 返回 true 表示通过预检或无需预检。
|
||||
//
|
||||
// passesRateLimitPreCheck performs rate limit precheck.
|
||||
// Returns true if passed or precheck not required.
|
||||
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
|
||||
if s.rateLimitService == nil || requestedModel == "" {
|
||||
return true
|
||||
}
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
|
||||
// 返回 nil 表示无可用账号。
|
||||
//
|
||||
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
|
||||
// Returns nil if no available account.
|
||||
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
ctx context.Context,
|
||||
accounts []Account,
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
platform string,
|
||||
useMixedScheduling bool,
|
||||
) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
|
||||
// 跳过被排除的账号
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查账号是否可用于当前请求
|
||||
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择最佳账号
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterGeminiAccount(acc, selected) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
|
||||
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
|
||||
//
|
||||
// isBetterGeminiAccount checks if candidate is better than current.
|
||||
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
|
||||
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
|
||||
// 优先级更高(数值更小)
|
||||
if candidate.Priority < current.Priority {
|
||||
return true
|
||||
}
|
||||
if candidate.Priority > current.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 同优先级,比较最后使用时间
|
||||
switch {
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
||||
// candidate 从未使用,优先
|
||||
return true
|
||||
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
||||
// current 从未使用,保持
|
||||
return false
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
||||
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
|
||||
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
|
||||
default:
|
||||
// 都使用过,选择最久未使用的
|
||||
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||||
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
@@ -800,6 +931,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
}
|
||||
|
||||
// 图片生成计费
|
||||
imageCount := 0
|
||||
imageSize := s.extractImageSize(body)
|
||||
if isImageGenerationModel(originalModel) {
|
||||
imageCount = 1
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
@@ -807,6 +945,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Stream: req.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1240,6 +1380,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
usage = &ClaudeUsage{}
|
||||
}
|
||||
|
||||
// 图片生成计费
|
||||
imageCount := 0
|
||||
imageSize := s.extractImageSize(body)
|
||||
if isImageGenerationModel(originalModel) {
|
||||
imageCount = 1
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
@@ -1247,6 +1394,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1841,6 +1990,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
||||
|
||||
var last map[string]any
|
||||
var lastWithParts map[string]any
|
||||
var collectedTextParts []string // Collect all text parts for aggregation
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
for {
|
||||
@@ -1852,7 +2002,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
||||
switch payload {
|
||||
case "", "[DONE]":
|
||||
if payload == "[DONE]" {
|
||||
return pickGeminiCollectResult(last, lastWithParts), usage, nil
|
||||
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
|
||||
}
|
||||
default:
|
||||
var parsed map[string]any
|
||||
@@ -1871,6 +2021,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
||||
}
|
||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||
lastWithParts = parsed
|
||||
// Collect text from each part for aggregation
|
||||
for _, part := range parts {
|
||||
if text, ok := part["text"].(string); ok && text != "" {
|
||||
collectedTextParts = append(collectedTextParts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1885,7 +2041,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
||||
}
|
||||
}
|
||||
|
||||
return pickGeminiCollectResult(last, lastWithParts), usage, nil
|
||||
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
|
||||
}
|
||||
|
||||
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
|
||||
@@ -1898,6 +2054,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
// mergeCollectedTextParts merges all collected text chunks into the final response.
|
||||
// This fixes the issue where non-streaming responses only returned the last chunk
|
||||
// instead of the complete aggregated text.
|
||||
func mergeCollectedTextParts(response map[string]any, textParts []string) map[string]any {
|
||||
if len(textParts) == 0 {
|
||||
return response
|
||||
}
|
||||
|
||||
// Join all text parts
|
||||
mergedText := strings.Join(textParts, "")
|
||||
|
||||
// Deep copy response
|
||||
result := make(map[string]any)
|
||||
for k, v := range response {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Get or create candidates
|
||||
candidates, ok := result["candidates"].([]any)
|
||||
if !ok || len(candidates) == 0 {
|
||||
candidates = []any{map[string]any{}}
|
||||
}
|
||||
|
||||
// Get first candidate
|
||||
candidate, ok := candidates[0].(map[string]any)
|
||||
if !ok {
|
||||
candidate = make(map[string]any)
|
||||
candidates[0] = candidate
|
||||
}
|
||||
|
||||
// Get or create content
|
||||
content, ok := candidate["content"].(map[string]any)
|
||||
if !ok {
|
||||
content = map[string]any{"role": "model"}
|
||||
candidate["content"] = content
|
||||
}
|
||||
|
||||
// Get existing parts
|
||||
existingParts, ok := content["parts"].([]any)
|
||||
if !ok {
|
||||
existingParts = []any{}
|
||||
}
|
||||
|
||||
// Find and update first text part, or create new one
|
||||
newParts := make([]any, 0, len(existingParts)+1)
|
||||
textUpdated := false
|
||||
|
||||
for _, p := range existingParts {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok {
|
||||
newParts = append(newParts, p)
|
||||
continue
|
||||
}
|
||||
if _, hasText := pm["text"]; hasText && !textUpdated {
|
||||
// Replace with merged text
|
||||
newPart := make(map[string]any)
|
||||
for k, v := range pm {
|
||||
newPart[k] = v
|
||||
}
|
||||
newPart["text"] = mergedText
|
||||
newParts = append(newParts, newPart)
|
||||
textUpdated = true
|
||||
} else {
|
||||
newParts = append(newParts, pm)
|
||||
}
|
||||
}
|
||||
|
||||
if !textUpdated {
|
||||
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
|
||||
}
|
||||
|
||||
content["parts"] = newParts
|
||||
result["candidates"] = candidates
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type geminiNativeStreamResult struct {
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
@@ -2816,3 +3049,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||||
func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
|
||||
var req struct {
|
||||
GenerationConfig *struct {
|
||||
ImageConfig *struct {
|
||||
ImageSize string `json:"imageSize"`
|
||||
} `json:"imageConfig"`
|
||||
} `json:"generationConfig"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return "2K"
|
||||
}
|
||||
|
||||
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
|
||||
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
|
||||
if size == "1K" || size == "2K" || size == "4K" {
|
||||
return size
|
||||
}
|
||||
}
|
||||
|
||||
return "2K"
|
||||
}
|
||||
|
||||
@@ -15,8 +15,10 @@ import (
|
||||
|
||||
// mockAccountRepoForGemini Gemini 测试用的 mock
|
||||
type mockAccountRepoForGemini struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
if m.listByPlatformFunc != nil {
|
||||
return m.listByPlatformFunc(ctx, platforms)
|
||||
}
|
||||
var result []Account
|
||||
platformSet := make(map[string]bool)
|
||||
for _, p := range platforms {
|
||||
@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
|
||||
return result, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
if m.listByGroupFunc != nil {
|
||||
return m.listByGroupFunc(ctx, groupID, platforms)
|
||||
}
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
type mockGatewayCacheForGemini struct {
|
||||
sessionBindings map[string]int64
|
||||
deletedSessions map[string]int
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
if m.sessionBindings == nil {
|
||||
return nil
|
||||
}
|
||||
if m.deletedSessions == nil {
|
||||
m.deletedSessions = make(map[string]int)
|
||||
}
|
||||
m.deletedSessions[sessionHash]++
|
||||
delete(m.sessionBindings, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
|
||||
// 粘性会话未命中,按优先级选择
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
|
||||
})
|
||||
|
||||
t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
|
||||
require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(9)
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return nil, nil
|
||||
},
|
||||
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
}, nil
|
||||
},
|
||||
accountsByID: map[int64]*Account{
|
||||
1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Priority: 1,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
|
||||
},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "supporting model")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-999": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return nil, errors.New("query failed")
|
||||
},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "query accounts failed")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
newTime := time.Now().Add(-1 * time.Hour)
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
|
||||
|
||||
72
backend/internal/service/gemini_native_signature_cleaner.go
Normal file
72
backend/internal/service/gemini_native_signature_cleaner.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
|
||||
// 以避免跨账号签名验证错误。
|
||||
//
|
||||
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
|
||||
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
|
||||
//
|
||||
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
|
||||
// to avoid cross-account signature validation errors.
|
||||
//
|
||||
// When sticky session switches accounts (e.g., original account becomes unavailable),
|
||||
// thoughtSignatures from the old account will cause validation failures on the new account.
|
||||
// By removing these signatures, we allow the new account to generate valid signatures.
|
||||
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
// 解析 JSON
|
||||
var data any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
|
||||
return body
|
||||
}
|
||||
|
||||
// 递归清理 thoughtSignature
|
||||
cleaned := cleanThoughtSignaturesRecursive(data)
|
||||
|
||||
// 重新序列化
|
||||
result, err := json.Marshal(cleaned)
|
||||
if err != nil {
|
||||
// 如果序列化失败,返回原始 body
|
||||
return body
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
|
||||
func cleanThoughtSignaturesRecursive(data any) any {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
// 创建新的 map,移除 thoughtSignature
|
||||
result := make(map[string]any, len(v))
|
||||
for key, value := range v {
|
||||
// 跳过 thoughtSignature 字段
|
||||
if key == "thoughtSignature" {
|
||||
continue
|
||||
}
|
||||
// 递归处理嵌套结构
|
||||
result[key] = cleanThoughtSignaturesRecursive(value)
|
||||
}
|
||||
return result
|
||||
|
||||
case []any:
|
||||
// 递归处理数组中的每个元素
|
||||
result := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = cleanThoughtSignaturesRecursive(item)
|
||||
}
|
||||
return result
|
||||
|
||||
default:
|
||||
// 基本类型(string, number, bool, null)直接返回
|
||||
return v
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -131,21 +132,32 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
}
|
||||
|
||||
// 3) Populate cache with TTL.
|
||||
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > geminiTokenCacheSkew:
|
||||
ttl = until - geminiTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > geminiTokenCacheSkew:
|
||||
ttl = until - geminiTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
|
||||
@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
|
||||
|
||||
// GenerateAuthURL generates an OAuth authorization URL with full scope
|
||||
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
return s.generateAuthURLWithScope(ctx, scope, proxyID)
|
||||
return s.generateAuthURLWithScope(ctx, oauth.ScopeOAuth, proxyID)
|
||||
}
|
||||
|
||||
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
|
||||
@@ -123,6 +122,7 @@ type TokenInfo struct {
|
||||
Scope string `json:"scope,omitempty"`
|
||||
OrgUUID string `json:"org_uuid,omitempty"`
|
||||
AccountUUID string `json:"account_uuid,omitempty"`
|
||||
EmailAddress string `json:"email_address,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
@@ -176,7 +176,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
|
||||
}
|
||||
|
||||
// Determine scope and if this is a setup token
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
// Internal API call uses ScopeAPI (org:create_api_key not supported)
|
||||
scope := oauth.ScopeAPI
|
||||
isSetupToken := false
|
||||
if input.Scope == "inference" {
|
||||
scope = oauth.ScopeInference
|
||||
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
||||
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
||||
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
||||
}
|
||||
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
|
||||
tokenInfo.AccountUUID = tokenResp.Account.UUID
|
||||
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
|
||||
if tokenResp.Account != nil {
|
||||
if tokenResp.Account.UUID != "" {
|
||||
tokenInfo.AccountUUID = tokenResp.Account.UUID
|
||||
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
|
||||
}
|
||||
if tokenResp.Account.EmailAddress != "" {
|
||||
tokenInfo.EmailAddress = tokenResp.Account.EmailAddress
|
||||
log.Printf("[OAuth] Got email_address: %s", tokenInfo.EmailAddress)
|
||||
}
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
|
||||
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
|
||||
type NormalizedCodexLimits struct {
|
||||
Used5hPercent *float64
|
||||
Reset5hSeconds *int
|
||||
Window5hMinutes *int
|
||||
Used7dPercent *float64
|
||||
Reset7dSeconds *int
|
||||
Window7dMinutes *int
|
||||
}
|
||||
|
||||
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
|
||||
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
|
||||
// Returns nil if snapshot is nil or has no useful data.
|
||||
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := &NormalizedCodexLimits{}
|
||||
|
||||
primaryMins := 0
|
||||
secondaryMins := 0
|
||||
hasPrimaryWindow := false
|
||||
hasSecondaryWindow := false
|
||||
|
||||
if s.PrimaryWindowMinutes != nil {
|
||||
primaryMins = *s.PrimaryWindowMinutes
|
||||
hasPrimaryWindow = true
|
||||
}
|
||||
if s.SecondaryWindowMinutes != nil {
|
||||
secondaryMins = *s.SecondaryWindowMinutes
|
||||
hasSecondaryWindow = true
|
||||
}
|
||||
|
||||
// Determine mapping based on window_minutes
|
||||
use5hFromPrimary := false
|
||||
use7dFromPrimary := false
|
||||
|
||||
if hasPrimaryWindow && hasSecondaryWindow {
|
||||
// Both known: smaller window is 5h, larger is 7d
|
||||
if primaryMins < secondaryMins {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
} else if hasPrimaryWindow {
|
||||
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
|
||||
if primaryMins <= 360 {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
} else if hasSecondaryWindow {
|
||||
// Only secondary known: classify by threshold
|
||||
if secondaryMins <= 360 {
|
||||
// 5h from secondary, so primary (if any data) is 7d
|
||||
use7dFromPrimary = true
|
||||
} else {
|
||||
// 7d from secondary, so primary (if any data) is 5h
|
||||
use5hFromPrimary = true
|
||||
}
|
||||
} else {
|
||||
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
|
||||
// Assign values
|
||||
if use5hFromPrimary {
|
||||
result.Used5hPercent = s.PrimaryUsedPercent
|
||||
result.Reset5hSeconds = s.PrimaryResetAfterSeconds
|
||||
result.Window5hMinutes = s.PrimaryWindowMinutes
|
||||
result.Used7dPercent = s.SecondaryUsedPercent
|
||||
result.Reset7dSeconds = s.SecondaryResetAfterSeconds
|
||||
result.Window7dMinutes = s.SecondaryWindowMinutes
|
||||
} else if use7dFromPrimary {
|
||||
result.Used7dPercent = s.PrimaryUsedPercent
|
||||
result.Reset7dSeconds = s.PrimaryResetAfterSeconds
|
||||
result.Window7dMinutes = s.PrimaryWindowMinutes
|
||||
result.Used5hPercent = s.SecondaryUsedPercent
|
||||
result.Reset5hSeconds = s.SecondaryResetAfterSeconds
|
||||
result.Window5hMinutes = s.SecondaryWindowMinutes
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// OpenAIUsage represents OpenAI API response usage
|
||||
type OpenAIUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
@@ -180,67 +266,26 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
|
||||
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
cacheKey := "openai:" + sessionHash
|
||||
|
||||
// 1. 尝试粘性会话命中
|
||||
// Try sticky session hit
|
||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
// 2. 获取可调度的 OpenAI 账号
|
||||
// Get schedulable OpenAI accounts
|
||||
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Select by priority + LRU
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
// Lower priority value means higher priority
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// keep selected (both never used)
|
||||
default:
|
||||
// Same priority, select least recently used
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 3. 按优先级 + LRU 选择最佳账号
|
||||
// Select by priority + LRU
|
||||
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -249,14 +294,138 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
return nil, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
// 4. Set sticky session
|
||||
// 4. 设置粘性会话绑定
|
||||
// Set sticky session binding
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
||||
//
|
||||
// tryStickySessionHit attempts to get account from sticky session.
|
||||
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
|
||||
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
if sessionHash == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err != nil || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, excluded := excludedIDs[accountID]; excluded {
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查账号是否需要清理粘性会话
|
||||
// Check if sticky session should be cleared
|
||||
if shouldClearStickySession(account) {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 验证账号是否可用于当前请求
|
||||
// Verify account is usable for current request
|
||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新会话 TTL 并返回账号
|
||||
// Refresh session TTL and return account
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
|
||||
return account
|
||||
}
|
||||
|
||||
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
|
||||
// 返回 nil 表示无可用账号。
|
||||
//
|
||||
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||
// Returns nil if no available account.
|
||||
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
|
||||
// 跳过被排除的账号
|
||||
// Skip excluded accounts
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
|
||||
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
|
||||
if !acc.IsSchedulable() || !acc.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择优先级最高且最久未使用的账号
|
||||
// Select highest priority and least recently used
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterAccount(acc, selected) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// isBetterAccount 判断 candidate 是否比 current 更优。
|
||||
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
|
||||
//
|
||||
// isBetterAccount checks if candidate is better than current.
|
||||
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
|
||||
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
|
||||
// 优先级更高(数值更小)
|
||||
// Higher priority (lower value)
|
||||
if candidate.Priority < current.Priority {
|
||||
return true
|
||||
}
|
||||
if candidate.Priority > current.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 同优先级,比较最后使用时间
|
||||
// Same priority, compare last used time
|
||||
switch {
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
||||
// candidate 从未使用,优先
|
||||
return true
|
||||
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
||||
// current 从未使用,保持
|
||||
return false
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
||||
// 都未使用,保持
|
||||
return false
|
||||
default:
|
||||
// 都使用过,选择最久未使用的
|
||||
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
@@ -325,29 +494,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
}
|
||||
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -778,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
@@ -1576,8 +1751,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractCodexUsageHeaders extracts Codex usage limits from response headers
|
||||
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
||||
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
||||
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
snapshot := &OpenAICodexUsageSnapshot{}
|
||||
hasData := false
|
||||
|
||||
@@ -1651,6 +1827,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
|
||||
// Convert snapshot to map for merging into Extra
|
||||
updates := make(map[string]any)
|
||||
|
||||
// Save raw primary/secondary fields for debugging/tracing
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
@@ -1674,109 +1852,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
}
|
||||
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
|
||||
|
||||
// Normalize to canonical 5h/7d fields based on window_minutes
|
||||
// This fixes the issue where OpenAI's primary/secondary naming is reversed
|
||||
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
|
||||
|
||||
// IMPORTANT: We can only reliably determine window type from window_minutes field
|
||||
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
|
||||
|
||||
var primaryWindowMins, secondaryWindowMins int
|
||||
var hasPrimaryWindow, hasSecondaryWindow bool
|
||||
|
||||
// Only use window_minutes for reliable window size comparison
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
primaryWindowMins = *snapshot.PrimaryWindowMinutes
|
||||
hasPrimaryWindow = true
|
||||
}
|
||||
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
|
||||
hasSecondaryWindow = true
|
||||
}
|
||||
|
||||
// Determine which is 5h and which is 7d
|
||||
var use5hFromPrimary, use7dFromPrimary bool
|
||||
var use5hFromSecondary, use7dFromSecondary bool
|
||||
|
||||
if hasPrimaryWindow && hasSecondaryWindow {
|
||||
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
|
||||
if primaryWindowMins < secondaryWindowMins {
|
||||
use5hFromPrimary = true
|
||||
use7dFromSecondary = true
|
||||
} else {
|
||||
use5hFromSecondary = true
|
||||
use7dFromPrimary = true
|
||||
// Normalize to canonical 5h/7d fields
|
||||
if normalized := snapshot.Normalize(); normalized != nil {
|
||||
if normalized.Used5hPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
|
||||
}
|
||||
} else if hasPrimaryWindow {
|
||||
// Only primary window size known: classify by absolute threshold
|
||||
if primaryWindowMins <= 360 {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
if normalized.Reset5hSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
|
||||
}
|
||||
} else if hasSecondaryWindow {
|
||||
// Only secondary window size known: classify by absolute threshold
|
||||
if secondaryWindowMins <= 360 {
|
||||
use5hFromSecondary = true
|
||||
} else {
|
||||
use7dFromSecondary = true
|
||||
if normalized.Window5hMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
|
||||
}
|
||||
} else {
|
||||
// No window_minutes available: cannot reliably determine window types
|
||||
// Fall back to legacy assumption (may be incorrect)
|
||||
// Assume primary=7d, secondary=5h based on historical observation
|
||||
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
|
||||
use5hFromSecondary = true
|
||||
if normalized.Used7dPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
|
||||
}
|
||||
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
|
||||
use7dFromPrimary = true
|
||||
if normalized.Reset7dSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
|
||||
}
|
||||
}
|
||||
|
||||
// Write canonical 5h fields
|
||||
if use5hFromPrimary {
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
} else if use5hFromSecondary {
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
}
|
||||
}
|
||||
|
||||
// Write canonical 7d fields
|
||||
if use7dFromPrimary {
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
} else if use7dFromSecondary {
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
if normalized.Window7dMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,19 +21,50 @@ type stubOpenAIAccountRepo struct {
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
for i := range r.accounts {
|
||||
if r.accounts[i].ID == id {
|
||||
return &r.accounts[i], nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type stubConcurrencyCache struct {
|
||||
ConcurrencyCache
|
||||
loadBatchErr error
|
||||
loadMap map[int64]*AccountLoadInfo
|
||||
acquireResults map[int64]bool
|
||||
waitCounts map[int64]int
|
||||
skipDefaultLoad bool
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if c.acquireResults != nil {
|
||||
if result, ok := c.acquireResults[accountID]; ok {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -42,8 +73,25 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if c.loadBatchErr != nil {
|
||||
return nil, c.loadBatchErr
|
||||
}
|
||||
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
if c.skipDefaultLoad && c.loadMap != nil {
|
||||
for _, acc := range accounts {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
for _, acc := range accounts {
|
||||
if c.loadMap != nil {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
continue
|
||||
}
|
||||
}
|
||||
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||
}
|
||||
return out, nil
|
||||
@@ -92,6 +140,51 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if c.waitCounts != nil {
|
||||
if count, ok := c.waitCounts[accountID]; ok {
|
||||
return count, nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type stubGatewayCache struct {
|
||||
sessionBindings map[string]int64
|
||||
deletedSessions map[string]int
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
if id, ok := c.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if c.sessionBindings == nil {
|
||||
c.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
c.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
if c.sessionBindings == nil {
|
||||
return nil
|
||||
}
|
||||
if c.deletedSessions == nil {
|
||||
c.deletedSessions = make(map[string]int)
|
||||
}
|
||||
c.deletedSessions[sessionHash]++
|
||||
delete(c.sessionBindings, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
@@ -182,6 +275,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
|
||||
sessionHash := "session-1"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %+v", acc)
|
||||
}
|
||||
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session to be deleted")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||
t.Fatalf("expected sticky session to bind to account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
|
||||
sessionHash := "session-2"
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %+v", selection)
|
||||
}
|
||||
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session to be deleted")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||
t.Fatalf("expected sticky session to bind to account 2")
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for unsupported model")
|
||||
}
|
||||
if acc != nil {
|
||||
t.Fatalf("expected nil account for unsupported model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "supporting model") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadBatchErr: errors.New("load batch failed"),
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
t.Fatalf("expected selection")
|
||||
}
|
||||
if selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %d", selection.Account.ID)
|
||||
}
|
||||
if cache.sessionBindings["openai:fallback"] != 2 {
|
||||
t.Fatalf("expected sticky session updated")
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan fallback")
|
||||
}
|
||||
if selection.Account == nil || selection.Account.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
|
||||
sessionHash := "bind"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session binding")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
|
||||
sessionHash := "sticky-wait"
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
waitCounts: map[int64]int{1: 0},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected sticky wait plan")
|
||||
}
|
||||
if selection.Account == nil || selection.Account.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 80},
|
||||
2: {AccountID: 2, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
if cache.sessionBindings["openai:load"] != 2 {
|
||||
t.Fatalf("expected sticky session updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
|
||||
sessionHash := "excluded"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
|
||||
sessionHash := "non-openai"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
|
||||
repo := stubOpenAIAccountRepo{accounts: []Account{}}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no accounts")
|
||||
}
|
||||
if acc != nil {
|
||||
t.Fatalf("expected nil account")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no available OpenAI accounts") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
resetAt := time.Now().Add(1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no candidates")
|
||||
}
|
||||
if selection != nil {
|
||||
t.Fatalf("expected nil selection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 100},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadBatchErr: errors.New("load batch failed"),
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 50},
|
||||
},
|
||||
skipDefaultLoad: true,
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
newTime := time.Now().Add(-1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
lastUsed := time.Now().Add(-1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 10},
|
||||
2: {AccountID: 2, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
|
||||
@@ -2,9 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := openai.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
|
||||
}
|
||||
|
||||
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
||||
@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
// Generate session ID
|
||||
sessionID, err := openai.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
|
||||
}
|
||||
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||
}
|
||||
if proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
// Get session
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||
}
|
||||
if proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, fmt.Errorf("account is not an OpenAI account")
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
||||
}
|
||||
|
||||
refreshToken := account.GetOpenAIRefreshToken()
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
|
||||
@@ -162,26 +162,37 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > openAITokenCacheSkew:
|
||||
ttl = until - openAITokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetOpenAIAccessToken()
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > openAITokenCacheSkew:
|
||||
ttl = until - openAITokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
|
||||
// handle429 处理429限流错误
|
||||
// 解析响应头获取重置时间,标记账号为限流状态
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
// 解析重置时间戳
|
||||
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
|
||||
if account.Platform == PlatformOpenAI {
|
||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("openai_account_rate_limited", "account_id", account.ID, "reset_at", *resetAt)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 尝试从响应头解析重置时间(Anthropic)
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
|
||||
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||
if resetTimestamp == "" {
|
||||
switch account.Platform {
|
||||
case PlatformOpenAI:
|
||||
// 尝试解析 OpenAI 的 usage_limit_reached 错误
|
||||
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
|
||||
return
|
||||
}
|
||||
case PlatformGemini, PlatformAntigravity:
|
||||
// 尝试解析 Gemini 格式(用于其他平台)
|
||||
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
}
|
||||
return
|
||||
}
|
||||
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
|
||||
return strings.Contains(msg, "sonnet")
|
||||
}
|
||||
|
||||
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
|
||||
// 返回 nil 表示无法从响应头中确定重置时间
|
||||
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
|
||||
snapshot := ParseCodexRateLimitHeaders(headers)
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 判断哪个限制被触发(used_percent >= 100)
|
||||
is7dExhausted := normalized.Used7dPercent != nil && *normalized.Used7dPercent >= 100
|
||||
is5hExhausted := normalized.Used5hPercent != nil && *normalized.Used5hPercent >= 100
|
||||
|
||||
// 优先使用被触发限制的重置时间
|
||||
if is7dExhausted && normalized.Reset7dSeconds != nil {
|
||||
resetAt := now.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
|
||||
slog.Info("openai_429_7d_limit_exhausted", "reset_after_seconds", *normalized.Reset7dSeconds, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
if is5hExhausted && normalized.Reset5hSeconds != nil {
|
||||
resetAt := now.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
|
||||
slog.Info("openai_429_5h_limit_exhausted", "reset_after_seconds", *normalized.Reset5hSeconds, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
// 都未达到100%但收到429,使用较长的重置时间
|
||||
var maxResetSecs int
|
||||
if normalized.Reset7dSeconds != nil && *normalized.Reset7dSeconds > maxResetSecs {
|
||||
maxResetSecs = *normalized.Reset7dSeconds
|
||||
}
|
||||
if normalized.Reset5hSeconds != nil && *normalized.Reset5hSeconds > maxResetSecs {
|
||||
maxResetSecs = *normalized.Reset5hSeconds
|
||||
}
|
||||
if maxResetSecs > 0 {
|
||||
resetAt := now.Add(time.Duration(maxResetSecs) * time.Second)
|
||||
slog.Info("openai_429_using_max_reset", "max_reset_seconds", maxResetSecs, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
// OpenAI 的 usage_limit_reached 错误格式:
|
||||
//
|
||||
// {
|
||||
// "error": {
|
||||
// "message": "The usage limit has been reached",
|
||||
// "type": "usage_limit_reached",
|
||||
// "resets_at": 1769404154,
|
||||
// "resets_in_seconds": 133107
|
||||
// }
|
||||
// }
|
||||
func parseOpenAIRateLimitResetTime(body []byte) *int64 {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
errObj, ok := parsed["error"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
|
||||
errType, _ := errObj["type"].(string)
|
||||
if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 优先使用 resets_at(Unix 时间戳)
|
||||
if resetsAt, ok := errObj["resets_at"].(float64); ok {
|
||||
ts := int64(resetsAt)
|
||||
return &ts
|
||||
}
|
||||
if resetsAt, ok := errObj["resets_at"].(string); ok {
|
||||
if ts, err := strconv.ParseInt(resetsAt, 10, 64); err == nil {
|
||||
return &ts
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 resets_at,尝试使用 resets_in_seconds
|
||||
if resetsInSeconds, ok := errObj["resets_in_seconds"].(float64); ok {
|
||||
ts := time.Now().Unix() + int64(resetsInSeconds)
|
||||
return &ts
|
||||
}
|
||||
if resetsInSeconds, ok := errObj["resets_in_seconds"].(string); ok {
|
||||
if sec, err := strconv.ParseInt(resetsInSeconds, 10, 64); err == nil {
|
||||
ts := time.Now().Unix() + sec
|
||||
return &ts
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handle529 处理529过载错误
|
||||
// 根据配置设置过载冷却时间
|
||||
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
|
||||
364
backend/internal/service/ratelimit_service_openai_test.go
Normal file
364
backend/internal/service/ratelimit_service_openai_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers when 7d limit is exhausted (100% used)
|
||||
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "384607") // ~4.5 days
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
|
||||
headers.Set("x-codex-secondary-used-percent", "3")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "17369") // ~4.8 hours
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should be approximately 384607 seconds from now
|
||||
expectedDuration := 384607 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_5hExhausted(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers when 5h limit is exhausted (100% used)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "50")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "500000")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "3600") // 1 hour
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should be approximately 3600 seconds from now
|
||||
expectedDuration := 3600 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Neither limit at 100%, should use the longer reset time
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "80")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "100000")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "90")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "5000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should use the max (100000 seconds from 7d window)
|
||||
expectedDuration := 100000 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_NoCodexHeaders(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// No codex headers at all
|
||||
headers := http.Header{}
|
||||
headers.Set("content-type", "application/json")
|
||||
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil resetAt when no codex headers, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100") // This is 5h
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "3600") // 1 hour
|
||||
headers.Set("x-codex-primary-window-minutes", "300") // 5 hours - smaller!
|
||||
headers.Set("x-codex-secondary-used-percent", "50")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "500000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "10080") // 7 days - larger!
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should correctly identify that primary is 5h (smaller window) and use its reset time
|
||||
expectedDuration := 3600 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits(t *testing.T) {
|
||||
// Test the Normalize() method directly
|
||||
pUsed := 100.0
|
||||
pReset := 384607
|
||||
pWindow := 10080
|
||||
sUsed := 3.0
|
||||
sReset := 17369
|
||||
sWindow := 300
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
PrimaryWindowMinutes: &pWindow,
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
SecondaryWindowMinutes: &sWindow,
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Primary has larger window (10080 > 300), so primary should be 7d
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
|
||||
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 384607 {
|
||||
t.Errorf("expected Reset7dSeconds=384607, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 3.0 {
|
||||
t.Errorf("expected Used5hPercent=3, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 17369 {
|
||||
t.Errorf("expected Reset5hSeconds=17369, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
|
||||
// Test when only primary has data, no window_minutes
|
||||
pUsed := 80.0
|
||||
pReset := 50000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
// No window_minutes, no secondary data
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 80.0 {
|
||||
t.Errorf("expected Used7dPercent=80, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 50000 {
|
||||
t.Errorf("expected Reset7dSeconds=50000, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
// Secondary (5h) should be nil
|
||||
if normalized.Used5hPercent != nil {
|
||||
t.Errorf("expected Used5hPercent=nil, got %v", *normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds != nil {
|
||||
t.Errorf("expected Reset5hSeconds=nil, got %v", *normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
|
||||
// Test when only secondary has data, no window_minutes
|
||||
sUsed := 60.0
|
||||
sReset := 3000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
// No window_minutes, no primary data
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
// So secondary goes to 5h
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 60.0 {
|
||||
t.Errorf("expected Used5hPercent=60, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 3000 {
|
||||
t.Errorf("expected Reset5hSeconds=3000, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
// Primary (7d) should be nil
|
||||
if normalized.Used7dPercent != nil {
|
||||
t.Errorf("expected Used7dPercent=nil, got %v", *normalized.Used7dPercent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_BothDataNoWindowMinutes(t *testing.T) {
|
||||
// Test when both have data but no window_minutes
|
||||
pUsed := 100.0
|
||||
pReset := 400000
|
||||
sUsed := 50.0
|
||||
sReset := 10000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
// No window_minutes
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
|
||||
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 400000 {
|
||||
t.Errorf("expected Reset7dSeconds=400000, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 50.0 {
|
||||
t.Errorf("expected Used5hPercent=50, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 10000 {
|
||||
t.Errorf("expected Reset5hSeconds=10000, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandle429_AnthropicPlatformUnaffected(t *testing.T) {
|
||||
// Verify that Anthropic platform accounts still use the original logic
|
||||
// This test ensures we don't break existing Claude account rate limiting
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate Anthropic 429 headers
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-reset", "1737820800") // A future Unix timestamp
|
||||
|
||||
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
|
||||
// because it only handles OpenAI platform
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
// Should return nil since there are no x-codex-* headers
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil for Anthropic headers, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_UserProvidedScenario(t *testing.T) {
|
||||
// This is the exact scenario from the user:
|
||||
// codex_7d_used_percent: 100
|
||||
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
|
||||
// codex_5h_used_percent: 3
|
||||
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers matching user's data
|
||||
// Note: We need to map the canonical 5h/7d back to primary/secondary
|
||||
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "384607")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days = 10080 minutes
|
||||
headers.Set("x-codex-secondary-used-percent", "3")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "17369")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours = 300 minutes
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt for user scenario")
|
||||
}
|
||||
|
||||
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
|
||||
expectedDuration := 384607 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
|
||||
// Verify it's approximately 4.45 days (384607 seconds)
|
||||
duration := resetAt.Sub(before)
|
||||
actualDays := duration.Hours() / 24.0
|
||||
|
||||
// 384607 / 86400 = ~4.45 days
|
||||
if actualDays < 4.4 || actualDays > 4.5 {
|
||||
t.Errorf("expected ~4.45 days, got %.2f days", actualDays)
|
||||
}
|
||||
|
||||
t.Logf("User scenario: reset_at=%v, duration=%.2f days", resetAt, actualDays)
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset(t *testing.T) {
|
||||
// Test that we return nil when there's used_percent but no reset_after_seconds
|
||||
// This should cause the caller to use the default 5-minute fallback
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
// No reset_after_seconds!
|
||||
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
// Should return nil since there's no reset time available
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil when no reset_after_seconds, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
|
||||
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
|
||||
// 返回 map[accountID]count,查询失败的账号不在 map 中
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
|
||||
|
||||
// IsSessionActive 检查特定会话是否活跃(未过期)
|
||||
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user