mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0b5e5bfa0 | ||
|
|
41d0657330 | ||
|
|
1a0cabbfd6 | ||
|
|
9b6dcc57bd | ||
|
|
b17704d6ef | ||
|
|
496469ac4e | ||
|
|
c1b52615be | ||
|
|
3af9940b85 | ||
|
|
22b1277572 | ||
|
|
aff98d5ae1 | ||
|
|
4e1bb2b445 | ||
|
|
dac6e52091 | ||
|
|
8987e0ba67 | ||
|
|
9d1751ec57 | ||
|
|
5d1c12e60e | ||
|
|
5b63a9b02d | ||
|
|
641e61073f | ||
|
|
095f457c57 | ||
|
|
1e57e88e43 | ||
|
|
b95ffce244 | ||
|
|
8f28a834f8 | ||
|
|
7424c73b05 | ||
|
|
1afd81b019 | ||
|
|
732d6495ea | ||
|
|
6d20ab8082 | ||
|
|
aa8ee33b0a | ||
|
|
5f630fbb19 | ||
|
|
bdbd2916f5 | ||
|
|
6dc89765fd | ||
|
|
f3233db01f | ||
|
|
6e12578bc5 | ||
|
|
a25faecadd | ||
|
|
5862e2d8d9 | ||
|
|
66d6454535 | ||
|
|
165553cfb0 | ||
|
|
b5467d610a | ||
|
|
57ff97960d | ||
|
|
5b5db88550 | ||
|
|
f03de00cb9 | ||
|
|
76aae5aa74 | ||
|
|
27ee141c1e | ||
|
|
e65574dea9 | ||
|
|
1ce9dc03f9 | ||
|
|
15ce914a62 | ||
|
|
959af1c8f6 | ||
|
|
c4d496da18 | ||
|
|
f3ea878ba2 | ||
|
|
d80469ea35 | ||
|
|
5fc30ea964 | ||
|
|
f68909a68b | ||
|
|
d162604f32 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
docs/claude-relay-service/
|
||||
.codex
|
||||
|
||||
# ===================
|
||||
# Go 后端
|
||||
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.1.116
|
||||
0.1.118
|
||||
|
||||
@@ -69,7 +69,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
affiliateRepository := repository.NewAffiliateRepository(client, db)
|
||||
affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
||||
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
@@ -80,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
totpCache := repository.NewTotpCache(redisClient)
|
||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
@@ -91,6 +93,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
announcementReadRepository := repository.NewAnnouncementReadRepository(client)
|
||||
announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
|
||||
announcementHandler := handler.NewAnnouncementHandler(announcementService)
|
||||
channelMonitorRepository := repository.NewChannelMonitorRepository(client, db)
|
||||
channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
|
||||
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
|
||||
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
|
||||
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
|
||||
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
|
||||
@@ -192,7 +197,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
||||
registry := payment.ProvideRegistry()
|
||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
|
||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||
opsHandler := admin.NewOpsHandler(opsService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
@@ -221,21 +226,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
channelHandler := admin.NewChannelHandler(channelService, billingService)
|
||||
sqlDB, err := repository.ProvideSQLDB(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB)
|
||||
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, sqlDB)
|
||||
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
|
||||
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
|
||||
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
|
||||
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
|
||||
channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
|
||||
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
|
||||
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
|
||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||
availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
|
||||
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -245,9 +242,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
|
||||
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
|
||||
availableChannelHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
|
||||
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
|
||||
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelHandler, idempotencyCoordinator, idempotencyCleanupService)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
@@ -263,6 +261,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
|
||||
@@ -652,6 +652,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
@@ -682,7 +683,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
183
backend/internal/handler/admin/affiliate_handler.go
Normal file
183
backend/internal/handler/admin/affiliate_handler.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AffiliateHandler handles admin affiliate (邀请返利) management:
|
||||
// listing users with custom settings, updating per-user invite codes
|
||||
// and exclusive rebate rates, and batch operations.
|
||||
type AffiliateHandler struct {
|
||||
affiliateService *service.AffiliateService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewAffiliateHandler creates a new admin affiliate handler.
|
||||
func NewAffiliateHandler(affiliateService *service.AffiliateService, adminService service.AdminService) *AffiliateHandler {
|
||||
return &AffiliateHandler{
|
||||
affiliateService: affiliateService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// ListUsers returns paginated users with custom affiliate settings.
|
||||
// GET /api/v1/admin/affiliates/users
|
||||
func (h *AffiliateHandler) ListUsers(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
search := c.Query("search")
|
||||
|
||||
entries, total, err := h.affiliateService.AdminListCustomUsers(c.Request.Context(), service.AffiliateAdminFilter{
|
||||
Search: search,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, entries, total, page, pageSize)
|
||||
}
|
||||
|
||||
// UpdateUserSettings updates a user's affiliate settings.
|
||||
// PUT /api/v1/admin/affiliates/users/:user_id
|
||||
//
|
||||
// Both fields are optional and applied independently.
|
||||
type UpdateAffiliateUserRequest struct {
|
||||
AffCode *string `json:"aff_code"`
|
||||
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
|
||||
// ClearRebateRate explicitly clears the per-user rate (sets it to NULL).
|
||||
// Used to disambiguate from "field not provided".
|
||||
ClearRebateRate bool `json:"clear_rebate_rate"`
|
||||
}
|
||||
|
||||
func (h *AffiliateHandler) UpdateUserSettings(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAffiliateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.AffCode != nil {
|
||||
if err := h.affiliateService.AdminUpdateUserAffCode(c.Request.Context(), userID, *req.AffCode); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.ClearRebateRate {
|
||||
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
} else if req.AffRebateRatePercent != nil {
|
||||
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, req.AffRebateRatePercent); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"user_id": userID})
|
||||
}
|
||||
|
||||
// ClearUserSettings removes ALL of a user's custom affiliate settings — clears
|
||||
// the exclusive rebate rate AND regenerates the invite code as a new system
|
||||
// random one. Conceptually this "removes the user from the custom list".
|
||||
//
|
||||
// Both writes happen in this handler; failure of one leaves the other applied,
|
||||
// but the operation is idempotent so the admin can re-run it safely.
|
||||
// DELETE /api/v1/admin/affiliates/users/:user_id
|
||||
func (h *AffiliateHandler) ClearUserSettings(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if _, err := h.affiliateService.AdminResetUserAffCode(c.Request.Context(), userID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"user_id": userID})
|
||||
}
|
||||
|
||||
// BatchSetRate applies the same rebate rate (or clears it) to multiple users.
|
||||
//
|
||||
// Protocol: pass `clear: true` to clear rates (aff_rebate_rate_percent is
|
||||
// ignored). Otherwise aff_rebate_rate_percent is required and applied to
|
||||
// every user_id. The explicit `clear` flag exists because Go's JSON unmarshal
|
||||
// can't distinguish a missing field from `null`, and a silent clear from a
|
||||
// frontend that forgot to include the rate would be a footgun.
|
||||
//
|
||||
// POST /api/v1/admin/affiliates/users/batch-rate
|
||||
type BatchSetRateRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
|
||||
Clear bool `json:"clear"`
|
||||
}
|
||||
|
||||
func (h *AffiliateHandler) BatchSetRate(c *gin.Context) {
|
||||
var req BatchSetRateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.BadRequest(c, "user_ids cannot be empty")
|
||||
return
|
||||
}
|
||||
if !req.Clear && req.AffRebateRatePercent == nil {
|
||||
response.BadRequest(c, "aff_rebate_rate_percent is required unless clear=true")
|
||||
return
|
||||
}
|
||||
rate := req.AffRebateRatePercent
|
||||
if req.Clear {
|
||||
rate = nil
|
||||
}
|
||||
if err := h.affiliateService.AdminBatchSetUserRebateRate(c.Request.Context(), req.UserIDs, rate); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"affected": len(req.UserIDs)})
|
||||
}
|
||||
|
||||
// AffiliateUserSummary is the minimal user shape returned by LookupUsers,
|
||||
// shared with the frontend's add-custom-user picker.
|
||||
type AffiliateUserSummary struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// LookupUsers searches users by email/username for the "add custom user" modal.
|
||||
// GET /api/v1/admin/affiliates/users/lookup?q=
|
||||
func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
|
||||
keyword := c.Query("q")
|
||||
if keyword == "" {
|
||||
response.Success(c, []AffiliateUserSummary{})
|
||||
return
|
||||
}
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 20, service.UserListFilters{Search: keyword}, "email", "asc")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
result := make([]AffiliateUserSummary, len(users))
|
||||
for i, u := range users {
|
||||
result[i] = AffiliateUserSummary{ID: u.ID, Email: u.Email, Username: u.Username}
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
@@ -185,6 +185,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
AffiliateRebateRate: settings.AffiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -241,6 +245,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
|
||||
}
|
||||
@@ -338,6 +344,10 @@ type UpdateSettingsRequest struct {
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
@@ -439,6 +449,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// Available Channels feature switch (user-facing)
|
||||
AvailableChannelsEnabled *bool `json:"available_channels_enabled"`
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled *bool `json:"affiliate_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -468,6 +481,43 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
affiliateRebateRate := previousSettings.AffiliateRebateRate
|
||||
if req.AffiliateRebateRate != nil {
|
||||
affiliateRebateRate = *req.AffiliateRebateRate
|
||||
}
|
||||
if affiliateRebateRate < service.AffiliateRebateRateMin {
|
||||
affiliateRebateRate = service.AffiliateRebateRateMin
|
||||
}
|
||||
if affiliateRebateRate > service.AffiliateRebateRateMax {
|
||||
affiliateRebateRate = service.AffiliateRebateRateMax
|
||||
}
|
||||
affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
|
||||
if req.AffiliateRebateFreezeHours != nil {
|
||||
affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
|
||||
}
|
||||
if affiliateRebateFreezeHours < 0 {
|
||||
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
|
||||
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
|
||||
if req.AffiliateRebateDurationDays != nil {
|
||||
affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
|
||||
}
|
||||
if affiliateRebateDurationDays < 0 {
|
||||
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
|
||||
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
|
||||
}
|
||||
affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
|
||||
if req.AffiliateRebatePerInviteeCap != nil {
|
||||
affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
|
||||
}
|
||||
if affiliateRebatePerInviteeCap < 0 {
|
||||
affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
|
||||
if req.TableDefaultPageSize <= 0 {
|
||||
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
|
||||
@@ -1119,6 +1169,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
AffiliateRebateRate: affiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: affiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
@@ -1252,6 +1306,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.AvailableChannelsEnabled
|
||||
}(),
|
||||
AffiliateEnabled: func() bool {
|
||||
if req.AffiliateEnabled != nil {
|
||||
return *req.AffiliateEnabled
|
||||
}
|
||||
return previousSettings.AffiliateEnabled
|
||||
}(),
|
||||
}
|
||||
|
||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||
@@ -1433,6 +1493,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -1488,6 +1552,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
|
||||
|
||||
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: updatedSettings.AffiliateEnabled,
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
|
||||
}
|
||||
@@ -1738,6 +1804,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.DefaultBalance != after.DefaultBalance {
|
||||
changed = append(changed, "default_balance")
|
||||
}
|
||||
if before.AffiliateRebateRate != after.AffiliateRebateRate {
|
||||
changed = append(changed, "affiliate_rebate_rate")
|
||||
}
|
||||
if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
|
||||
changed = append(changed, "affiliate_rebate_freeze_hours")
|
||||
}
|
||||
if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
|
||||
changed = append(changed, "affiliate_rebate_duration_days")
|
||||
}
|
||||
if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
|
||||
changed = append(changed, "affiliate_rebate_per_invitee_cap")
|
||||
}
|
||||
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
|
||||
changed = append(changed, "default_subscriptions")
|
||||
}
|
||||
@@ -1853,6 +1931,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
|
||||
changed = append(changed, "available_channels_enabled")
|
||||
}
|
||||
if before.AffiliateEnabled != after.AffiliateEnabled {
|
||||
changed = append(changed, "affiliate_enabled")
|
||||
}
|
||||
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
||||
return changed
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ type RegisterRequest struct {
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
PromoCode string `json:"promo_code"` // 注册优惠码
|
||||
InvitationCode string `json:"invitation_code"` // 邀请码
|
||||
AffCode string `json:"aff_code"` // 邀请返利码
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
@@ -164,7 +165,15 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
_, user, err := h.authService.RegisterWithVerification(
|
||||
c.Request.Context(),
|
||||
req.Email,
|
||||
req.Password,
|
||||
req.VerifyCode,
|
||||
req.PromoCode,
|
||||
req.InvitationCode,
|
||||
req.AffCode,
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct {
|
||||
VerifyCode string `json:"verify_code,omitempty"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
InvitationCode string `json:"invitation_code,omitempty"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
user,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
strings.TrimSpace(req.AffCode),
|
||||
); err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rollbackCreatedUser(err) {
|
||||
|
||||
@@ -2210,6 +2210,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
nil,
|
||||
nil,
|
||||
options.defaultSubAssigner,
|
||||
nil,
|
||||
)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
var totpSvc *service.TotpService
|
||||
|
||||
@@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
|
||||
|
||||
type completeOIDCOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := &AuthHandler{authService: authService}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService
|
||||
|
||||
type completeWeChatOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -1399,6 +1399,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
return &AuthHandler{
|
||||
|
||||
@@ -106,10 +106,14 @@ type SystemSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -191,6 +195,9 @@ type SystemSettings struct {
|
||||
|
||||
// Available Channels feature switch (user-facing aggregate view)
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -243,6 +250,8 @@ type PublicSettings struct {
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
}
|
||||
|
||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||
|
||||
@@ -34,6 +34,7 @@ type AdminHandlers struct {
|
||||
ChannelMonitor *admin.ChannelMonitorHandler
|
||||
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
|
||||
Payment *admin.PaymentHandler
|
||||
Affiliate *admin.AffiliateHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
@@ -130,6 +130,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
@@ -153,6 +154,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
|
||||
@@ -116,7 +116,7 @@ func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
|
||||
c.Set(opsModelKey, "gpt-5.3-codex")
|
||||
c.Set(opsAccountIDKey, int64(123))
|
||||
c.Header("x-request-id", "rid-compact-ok")
|
||||
@@ -142,7 +142,7 @@ func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
|
||||
c.Status(http.StatusBadGateway)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
@@ -180,7 +180,7 @@ func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
@@ -238,6 +238,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
|
||||
requireCompact := isOpenAIRemoteCompactPath(c)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
@@ -256,6 +257,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
requireCompact,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_select_failed",
|
||||
@@ -263,6 +265,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -644,6 +650,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
currentRoutingModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_messages.account_select_failed",
|
||||
@@ -1167,6 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
reqModel,
|
||||
nil,
|
||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
|
||||
|
||||
@@ -117,7 +117,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -215,7 +215,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
|
||||
require.NoError(t, err)
|
||||
|
||||
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -302,7 +302,7 @@ func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *t
|
||||
require.NoError(t, err)
|
||||
|
||||
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -342,7 +342,7 @@ func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -75,5 +75,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,10 +14,11 @@ import (
|
||||
|
||||
// UserHandler handles user-related requests
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
authService *service.AuthService
|
||||
emailService *service.EmailService
|
||||
emailCache service.EmailCache
|
||||
userService *service.UserService
|
||||
authService *service.AuthService
|
||||
emailService *service.EmailService
|
||||
emailCache service.EmailCache
|
||||
affiliateService *service.AffiliateService
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new UserHandler
|
||||
@@ -26,12 +27,14 @@ func NewUserHandler(
|
||||
authService *service.AuthService,
|
||||
emailService *service.EmailService,
|
||||
emailCache service.EmailCache,
|
||||
affiliateService *service.AffiliateService,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
affiliateService: affiliateService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,6 +162,44 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
response.Success(c, profileResp)
|
||||
}
|
||||
|
||||
// GetAffiliate returns the current user's affiliate details.
|
||||
// GET /api/v1/user/aff
|
||||
func (h *UserHandler) GetAffiliate(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, detail)
|
||||
}
|
||||
|
||||
// TransferAffiliateQuota transfers all available affiliate quota into current balance.
|
||||
// POST /api/v1/user/aff/transfer
|
||||
func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"transferred_quota": transferred,
|
||||
"balance": balance,
|
||||
})
|
||||
}
|
||||
|
||||
type StartIdentityBindingRequest struct {
|
||||
Provider string `json:"provider" binding:"required"`
|
||||
RedirectTo string `json:"redirect_to"`
|
||||
|
||||
@@ -142,7 +142,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -200,7 +200,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -283,7 +283,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -362,7 +362,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -511,8 +511,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -566,7 +566,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -625,8 +625,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -668,8 +668,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -712,8 +712,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -750,7 +750,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -37,6 +37,7 @@ func ProvideAdminHandlers(
|
||||
channelMonitorHandler *admin.ChannelMonitorHandler,
|
||||
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
|
||||
paymentHandler *admin.PaymentHandler,
|
||||
affiliateHandler *admin.AffiliateHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -67,6 +68,7 @@ func ProvideAdminHandlers(
|
||||
ChannelMonitor: channelMonitorHandler,
|
||||
ChannelMonitorTemplate: channelMonitorTemplateHandler,
|
||||
Payment: paymentHandler,
|
||||
Affiliate: affiliateHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,6 +171,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewChannelMonitorHandler,
|
||||
admin.NewChannelMonitorRequestTemplateHandler,
|
||||
admin.NewPaymentHandler,
|
||||
admin.NewAffiliateHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -25,6 +25,7 @@ const (
|
||||
easypayStatusPaid = 1
|
||||
easypayHTTPTimeout = 10 * time.Second
|
||||
maxEasypayResponseSize = 1 << 20 // 1MB
|
||||
maxEasypayErrorSummary = 512
|
||||
tradeStatusSuccess = "TRADE_SUCCESS"
|
||||
signTypeMD5 = "MD5"
|
||||
paymentModePopup = "popup"
|
||||
@@ -42,17 +43,55 @@ type EasyPay struct {
|
||||
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
|
||||
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
|
||||
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
|
||||
if config[k] == "" {
|
||||
if strings.TrimSpace(config[k]) == "" {
|
||||
return nil, fmt.Errorf("easypay config missing required key: %s", k)
|
||||
}
|
||||
}
|
||||
cfg := make(map[string]string, len(config))
|
||||
for k, v := range config {
|
||||
cfg[k] = v
|
||||
}
|
||||
cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
|
||||
return &EasyPay{
|
||||
instanceID: instanceID,
|
||||
config: config,
|
||||
config: cfg,
|
||||
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeEasyPayAPIBase(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return ""
|
||||
}
|
||||
if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
parsed.RawPath = ""
|
||||
parsed.Path = trimEasyPayEndpointPath(parsed.Path)
|
||||
return strings.TrimRight(parsed.String(), "/")
|
||||
}
|
||||
return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
|
||||
}
|
||||
|
||||
func trimEasyPayEndpointPath(path string) string {
|
||||
path = strings.TrimRight(strings.TrimSpace(path), "/")
|
||||
lower := strings.ToLower(path)
|
||||
for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
|
||||
if strings.HasSuffix(lower, endpoint) {
|
||||
return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (e *EasyPay) apiBase() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeEasyPayAPIBase(e.config["apiBase"])
|
||||
}
|
||||
|
||||
func (e *EasyPay) Name() string { return "EasyPay" }
|
||||
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
|
||||
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
|
||||
@@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
|
||||
for k, v := range params {
|
||||
q.Set(k, v)
|
||||
}
|
||||
base := strings.TrimRight(e.config["apiBase"], "/")
|
||||
payURL := base + "/submit.php?" + q.Encode()
|
||||
payURL := e.apiBase() + "/submit.php?" + q.Encode()
|
||||
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
|
||||
}
|
||||
|
||||
@@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
|
||||
params["sign"] = easyPaySign(params, e.config["pkey"])
|
||||
params["sign_type"] = signTypeMD5
|
||||
|
||||
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
|
||||
body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay create: %w", err)
|
||||
}
|
||||
@@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
|
||||
"act": "order", "pid": e.config["pid"],
|
||||
"key": e.config["pkey"], "out_trade_no": tradeNo,
|
||||
}
|
||||
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
|
||||
body, err := e.post(ctx, e.apiBase()+"/api.php", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay query: %w", err)
|
||||
}
|
||||
@@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
|
||||
}
|
||||
|
||||
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||
params := map[string]string{
|
||||
"pid": e.config["pid"], "key": e.config["pkey"],
|
||||
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
|
||||
attempts := e.refundAttempts(req)
|
||||
if len(attempts) == 0 {
|
||||
return nil, fmt.Errorf("easypay refund missing order identifier")
|
||||
}
|
||||
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay refund: %w", err)
|
||||
var firstErr error
|
||||
for i, attempt := range attempts {
|
||||
body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay refund request: %w", err)
|
||||
}
|
||||
if err := parseEasyPayRefundResponse(status, body); err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
type easyPayRefundAttempt struct {
|
||||
params map[string]string
|
||||
refundID string
|
||||
}
|
||||
|
||||
func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
|
||||
base := map[string]string{
|
||||
"pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
|
||||
}
|
||||
var attempts []easyPayRefundAttempt
|
||||
if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
|
||||
params := cloneStringMap(base)
|
||||
params["out_trade_no"] = orderID
|
||||
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
|
||||
}
|
||||
if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
|
||||
params := cloneStringMap(base)
|
||||
params["trade_no"] = tradeNo
|
||||
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
|
||||
}
|
||||
return attempts
|
||||
}
|
||||
|
||||
func cloneStringMap(in map[string]string) map[string]string {
|
||||
out := make(map[string]string, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isEasyPayRefundOrderNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
lower := strings.ToLower(msg)
|
||||
return strings.Contains(msg, "订单编号不存在") ||
|
||||
strings.Contains(msg, "订单不存在") ||
|
||||
strings.Contains(lower, "order not found") ||
|
||||
strings.Contains(lower, "not exist")
|
||||
}
|
||||
|
||||
func parseEasyPayRefundResponse(status int, body []byte) error {
|
||||
summary := summarizeEasyPayResponse(body)
|
||||
if status < http.StatusOK || status >= http.StatusMultipleChoices {
|
||||
return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(string(body))
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
|
||||
lower := strings.ToLower(trimmed)
|
||||
if strings.HasPrefix(lower, "<!doctype html") || strings.HasPrefix(lower, "<html") ||
|
||||
(strings.HasPrefix(lower, "<") && strings.Contains(lower, "html")) {
|
||||
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Code any `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("easypay parse refund: %w", err)
|
||||
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
if resp.Code != easypayCodeSuccess {
|
||||
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg)
|
||||
if !easyPayResponseCodeIsSuccess(resp.Code) {
|
||||
msg := strings.TrimSpace(resp.Msg)
|
||||
if msg == "" {
|
||||
msg = summary
|
||||
}
|
||||
return fmt.Errorf("easypay refund failed (HTTP %d): %s", status, msg)
|
||||
}
|
||||
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func easyPayResponseCodeIsSuccess(code any) bool {
|
||||
switch v := code.(type) {
|
||||
case float64:
|
||||
return int(v) == easypayCodeSuccess
|
||||
case string:
|
||||
n, err := strconv.Atoi(strings.TrimSpace(v))
|
||||
return err == nil && n == easypayCodeSuccess
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeEasyPayResponse(body []byte) string {
|
||||
summary := strings.Join(strings.Fields(string(body)), " ")
|
||||
if summary == "" {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(summary) > maxEasypayErrorSummary {
|
||||
return summary[:maxEasypayErrorSummary] + "..."
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func (e *EasyPay) resolveCID(paymentType string) string {
|
||||
@@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
|
||||
}
|
||||
|
||||
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
|
||||
body, _, err := e.postRaw(ctx, endpoint, params)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
|
||||
form := url.Values{}
|
||||
for k, v := range params {
|
||||
form.Set(k, v)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
resp, err := e.httpClient.Do(req)
|
||||
client := e.httpClient
|
||||
if client == nil {
|
||||
client = &http.Client{Timeout: easypayHTTPTimeout}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
return body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func easyPaySign(params map[string]string, pkey string) string {
|
||||
|
||||
196
backend/internal/payment/provider/easypay_refund_test.go
Normal file
196
backend/internal/payment/provider/easypay_refund_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
func TestNormalizeEasyPayAPIBase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{input: "https://zpayz.cn", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
|
||||
t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var gotPath string
|
||||
var gotQuery url.Values
|
||||
var gotForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotQuery = r.URL.Query()
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Errorf("ParseForm: %v", err)
|
||||
}
|
||||
gotForm = r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL+"/mapi.php")
|
||||
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
TradeNo: "trade-123",
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Refund returned error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.Status != payment.ProviderStatusSuccess {
|
||||
t.Fatalf("Refund response = %+v, want success", resp)
|
||||
}
|
||||
if gotPath != "/api.php" {
|
||||
t.Fatalf("refund path = %q, want /api.php", gotPath)
|
||||
}
|
||||
if gotQuery.Get("act") != "refund" {
|
||||
t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
|
||||
}
|
||||
for key, want := range map[string]string{
|
||||
"pid": "pid-1",
|
||||
"key": "pkey-1",
|
||||
"out_trade_no": "out-456",
|
||||
"money": "1.50",
|
||||
} {
|
||||
if got := gotForm.Get(key); got != want {
|
||||
t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
|
||||
}
|
||||
}
|
||||
if got := gotForm.Get("trade_no"); got != "" {
|
||||
t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var gotForms []url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api.php" {
|
||||
t.Errorf("refund path = %q, want /api.php", r.URL.Path)
|
||||
}
|
||||
if r.URL.Query().Get("act") != "refund" {
|
||||
t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Errorf("ParseForm: %v", err)
|
||||
}
|
||||
gotForms = append(gotForms, r.PostForm)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if len(gotForms) == 1 {
|
||||
_, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL+"/mapi.php")
|
||||
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
TradeNo: "trade-123",
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Refund returned error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
|
||||
t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
|
||||
}
|
||||
if len(gotForms) != 2 {
|
||||
t.Fatalf("refund attempts = %d, want 2", len(gotForms))
|
||||
}
|
||||
if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
|
||||
t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
|
||||
}
|
||||
if got := gotForms[0].Get("trade_no"); got != "" {
|
||||
t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
|
||||
}
|
||||
if got := gotForms[1].Get("trade_no"); got != "trade-123" {
|
||||
t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
|
||||
}
|
||||
if got := gotForms[1].Get("out_trade_no"); got != "" {
|
||||
t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundResponseErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
want string
|
||||
}{
|
||||
{name: "html response", statusCode: http.StatusOK, body: "<html>bad config</html>", want: "non-JSON response (HTTP 200): <html>bad config</html>"},
|
||||
{name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
|
||||
{name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
|
||||
{name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): <empty>"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, _ = w.Write([]byte(tt.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL)
|
||||
_, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Refund returned nil error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.want) {
|
||||
t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
|
||||
t.Helper()
|
||||
|
||||
provider, err := NewEasyPay("test-instance", map[string]string{
|
||||
"pid": "pid-1",
|
||||
"pkey": "pkey-1",
|
||||
"apiBase": apiBase,
|
||||
"notifyUrl": "https://example.com/notify",
|
||||
"returnUrl": "https://example.com/return",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewEasyPay: %v", err)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
|
||||
assert.Equal(t, 5, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cached",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Cached response"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 54006,
|
||||
OutputTokens: 123,
|
||||
TotalTokens: 54129,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 50688,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
|
||||
assert.Equal(t, 3318, anth.Usage.InputTokens)
|
||||
assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 123, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cached_clamp",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 5,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 150,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
|
||||
assert.Equal(t, 0, anth.Usage.InputTokens)
|
||||
assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 5, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_456",
|
||||
@@ -343,6 +392,36 @@ func TestStreamingTextOnly(t *testing.T) {
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 54006,
|
||||
OutputTokens: 123,
|
||||
TotalTokens: 54129,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 50688,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, 3318, events[0].Usage.InputTokens)
|
||||
assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 123, events[0].Usage.OutputTokens)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingToolCall(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
|
||||
@@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
|
||||
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
|
||||
|
||||
if resp.Usage != nil {
|
||||
out.Usage = AnthropicUsage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil {
|
||||
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
|
||||
if usage == nil {
|
||||
return AnthropicUsage{}
|
||||
}
|
||||
|
||||
cachedTokens := 0
|
||||
if usage.InputTokensDetails != nil {
|
||||
cachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
|
||||
inputTokens := usage.InputTokens - cachedTokens
|
||||
if inputTokens < 0 {
|
||||
inputTokens = 0
|
||||
}
|
||||
|
||||
return AnthropicUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: cachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
@@ -466,11 +482,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
|
||||
stopReason := "end_turn"
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
state.InputTokens = evt.Response.Usage.InputTokens
|
||||
state.OutputTokens = evt.Response.Usage.OutputTokens
|
||||
if evt.Response.Usage.InputTokensDetails != nil {
|
||||
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
|
||||
state.InputTokens = usage.InputTokens
|
||||
state.OutputTokens = usage.OutputTokens
|
||||
state.CacheReadInputTokens = usage.CacheReadInputTokens
|
||||
}
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
|
||||
@@ -390,7 +390,7 @@ func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
|
||||
var out []AnthropicTool
|
||||
for _, t := range tools {
|
||||
switch t.Type {
|
||||
case "web_search":
|
||||
case "web_search", "google_search", "web_search_20250305":
|
||||
out = append(out, AnthropicTool{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
|
||||
@@ -12,17 +12,23 @@ import "encoding/json"
|
||||
|
||||
// AnthropicRequest is the request body for POST /v1/messages.
|
||||
type AnthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
// Metadata 会被原样透传给上游。OAuth/Claude-Code 路径依赖 metadata.user_id
|
||||
// 参与上游的"是否为官方 Claude Code 请求"判定;如果经由本结构体重新序列化
|
||||
// 时丢弃该字段,网关侧后续的 metadata 重写(ensureClaudeOAuthMetadataUserID/
|
||||
// RewriteUserIDWithMasking) 在 body 里拿不到起点,就无法重建一个合法的
|
||||
// user_id,进而导致请求被归类为第三方 app。
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
@@ -76,10 +82,18 @@ type AnthropicImageSource struct {
|
||||
|
||||
// AnthropicTool describes a tool available to the model.
|
||||
type AnthropicTool struct {
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicCacheControl 对应 Anthropic API 的 cache_control 字段。
|
||||
// ttl 默认由调用方决定;本项目策略见 claude.DefaultCacheControlTTL。
|
||||
type AnthropicCacheControl struct {
|
||||
Type string `json:"type"` // "ephemeral"
|
||||
TTL string `json:"ttl,omitempty"` // "5m" / "1h" / 省略=默认 5m(由 Anthropic 判定)
|
||||
}
|
||||
|
||||
// AnthropicResponse is the non-streaming response from POST /v1/messages.
|
||||
|
||||
@@ -4,6 +4,12 @@ package claude
|
||||
// Claude Code 客户端相关常量
|
||||
|
||||
// Beta header 常量
|
||||
//
|
||||
// 这里的常量对齐真实 Claude Code CLI 的最新流量(截至 2026-04)。
|
||||
// 选型参考:与 Parrot (src/transform/cc_mimicry.py) 的 BETAS 保持一致,
|
||||
// 原因:Anthropic 上游会基于 anthropic-beta 的完整集合判定请求来源;
|
||||
// 缺少任何"官方 Claude Code 请求才会带"的 beta,都会被降级到第三方额度,
|
||||
// 对应报错:`Third-party apps now draw from your extra usage, not your plan limits.`
|
||||
const (
|
||||
BetaOAuth = "oauth-2025-04-20"
|
||||
BetaClaudeCode = "claude-code-20250219"
|
||||
@@ -12,6 +18,13 @@ const (
|
||||
BetaTokenCounting = "token-counting-2024-11-01"
|
||||
BetaContext1M = "context-1m-2025-08-07"
|
||||
BetaFastMode = "fast-mode-2026-02-01"
|
||||
|
||||
// 新增(对齐官方 CLI 2.1.9x 以来的流量)
|
||||
BetaPromptCachingScope = "prompt-caching-scope-2026-01-05"
|
||||
BetaEffort = "effort-2025-11-24"
|
||||
BetaRedactThinking = "redact-thinking-2026-02-12"
|
||||
BetaContextManagement = "context-management-2025-06-27"
|
||||
BetaExtendedCacheTTL = "extended-cache-ttl-2025-04-11"
|
||||
)
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
@@ -44,11 +57,43 @@ const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," +
|
||||
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// DefaultCacheControlTTL 是网关代理为自己生成的 cache_control 块默认使用的 ttl。
|
||||
// 真实 Claude Code CLI 当前使用 "1h",但本仓策略是"客户端透传 ttl 优先;
|
||||
// 客户端缺省时统一使用 5m",这样既不浪费 1h 缓存额度,也保留客户端自定义能力。
|
||||
const DefaultCacheControlTTL = "5m"
|
||||
|
||||
// CLICurrentVersion 是 sub2api 当前对外伪装的 Claude Code CLI 版本号(三段 semver)。
|
||||
// 用于 billing attribution block 中的 cc_version=X.Y.Z.{fp} 前缀以及 fingerprint 计算。
|
||||
// 必须与 DefaultHeaders["User-Agent"] 中的版本号严格一致;不一致会被 Anthropic 判第三方。
|
||||
const CLICurrentVersion = "2.1.92"
|
||||
|
||||
// FullClaudeCodeMimicryBetas 返回最"像"真实 Claude Code CLI 的完整 beta 列表,
|
||||
// 用于 OAuth 账号伪装成 Claude Code 时使用。
|
||||
// 顺序与真实 CLI 抓包一致。
|
||||
//
|
||||
// 使用建议:
|
||||
// - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。
|
||||
// - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
|
||||
// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
|
||||
func FullClaudeCodeMimicryBetas() []string {
|
||||
return []string{
|
||||
BetaClaudeCode,
|
||||
BetaOAuth,
|
||||
BetaInterleavedThinking,
|
||||
BetaPromptCachingScope,
|
||||
BetaEffort,
|
||||
BetaRedactThinking,
|
||||
BetaContextManagement,
|
||||
BetaExtendedCacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||
var DefaultHeaders = map[string]string{
|
||||
// Keep these in sync with recent Claude CLI traffic to reduce the chance
|
||||
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
|
||||
"User-Agent": "claude-cli/2.1.22 (external, cli)",
|
||||
// 版本参考:对齐 Parrot (src/transform/cc_mimicry.py:49) 的 CLI_USER_AGENT。
|
||||
"User-Agent": "claude-cli/2.1.92 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
"X-Stainless-Package-Version": "0.70.0",
|
||||
"X-Stainless-OS": "Linux",
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package repository
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRelevant(t *testing.T) {
|
||||
updates := map[string]any{
|
||||
"openai_compact_supported": true,
|
||||
"openai_compact_checked_at": "2026-04-10T10:00:00Z",
|
||||
}
|
||||
|
||||
if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||
t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
|
||||
}
|
||||
}
|
||||
762
backend/internal/repository/affiliate_repo.go
Normal file
762
backend/internal/repository/affiliate_repo.go
Normal file
@@ -0,0 +1,762 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
affiliateCodeLength = 12
|
||||
affiliateCodeMaxAttempts = 12
|
||||
)
|
||||
|
||||
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
|
||||
|
||||
type affiliateQueryExecer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
type affiliateRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
|
||||
return &affiliateRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
|
||||
if userID <= 0 {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return ensureUserAffiliateWithClient(ctx, client, userID)
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return queryAffiliateByCode(ctx, client, code)
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
|
||||
var bound bool
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := txClient.ExecContext(txCtx,
|
||||
"UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
|
||||
inviterID, userID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bind inviter: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
bound = false
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx,
|
||||
"UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
|
||||
inviterID,
|
||||
); err != nil {
|
||||
return fmt.Errorf("increment inviter aff_count: %w", err)
|
||||
}
|
||||
bound = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bound, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
|
||||
if amount <= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var applied bool
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
// freezeHours > 0: add to frozen quota; == 0: add to available quota directly
|
||||
var updateSQL string
|
||||
if freezeHours > 0 {
|
||||
updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
|
||||
} else {
|
||||
updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
|
||||
}
|
||||
res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
applied = false
|
||||
return nil
|
||||
}
|
||||
|
||||
if freezeHours > 0 {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
|
||||
inviterID, amount, inviteeUserID, freezeHours); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
} else {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
applied = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx,
|
||||
`SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`,
|
||||
inviterID, inviteeUserID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
var total float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return total, rows.Close()
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
|
||||
var thawed float64
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
var err error
|
||||
thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID)
|
||||
return err
|
||||
})
|
||||
return thawed, err
|
||||
}
|
||||
|
||||
// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx.
|
||||
func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
|
||||
rows, err := txClient.QueryContext(txCtx, `
|
||||
WITH matured AS (
|
||||
UPDATE user_affiliate_ledger
|
||||
SET frozen_until = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
AND frozen_until IS NOT NULL
|
||||
AND frozen_until <= NOW()
|
||||
RETURNING amount
|
||||
)
|
||||
SELECT COALESCE(SUM(amount), 0) FROM matured`, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("thaw frozen quota: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var thawed float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&thawed); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if thawed <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
_, err = txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_quota = aff_quota + $1,
|
||||
aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, thawed, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("move thawed quota: %w", err)
|
||||
}
|
||||
return thawed, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
|
||||
var transferred float64
|
||||
var newBalance float64
|
||||
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Thaw any matured frozen quota before transfer.
|
||||
if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
|
||||
return fmt.Errorf("thaw before transfer: %w", err)
|
||||
}
|
||||
|
||||
rows, err := txClient.QueryContext(txCtx, `
|
||||
WITH claimed AS (
|
||||
SELECT aff_quota::double precision AS amount
|
||||
FROM user_affiliates
|
||||
WHERE user_id = $1
|
||||
AND aff_quota > 0
|
||||
FOR UPDATE
|
||||
),
|
||||
cleared AS (
|
||||
UPDATE user_affiliates ua
|
||||
SET aff_quota = 0,
|
||||
updated_at = NOW()
|
||||
FROM claimed c
|
||||
WHERE ua.user_id = $1
|
||||
RETURNING c.amount
|
||||
)
|
||||
SELECT amount
|
||||
FROM cleared`, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("claim affiliate quota: %w", err)
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
_ = rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return service.ErrAffiliateQuotaEmpty
|
||||
}
|
||||
if err := rows.Scan(&transferred); err != nil {
|
||||
_ = rows.Close()
|
||||
return err
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if transferred <= 0 {
|
||||
return service.ErrAffiliateQuotaEmpty
|
||||
}
|
||||
|
||||
affected, err := txClient.User.Update().
|
||||
Where(user.IDEQ(userID)).
|
||||
AddBalance(transferred).
|
||||
AddTotalRecharged(transferred).
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credit user balance by affiliate quota: %w", err)
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
newBalance, err = queryUserBalance(txCtx, txClient, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
|
||||
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return transferred, newBalance, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.created_at,
|
||||
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
|
||||
FROM user_affiliates ua
|
||||
LEFT JOIN users u ON u.id = ua.user_id
|
||||
LEFT JOIN user_affiliate_ledger ual
|
||||
ON ual.user_id = $1
|
||||
AND ual.source_user_id = ua.user_id
|
||||
AND ual.action = 'accrue'
|
||||
WHERE ua.inviter_id = $1
|
||||
GROUP BY ua.user_id, u.email, u.username, ua.created_at
|
||||
ORDER BY ua.created_at DESC
|
||||
LIMIT $2`, inviterID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
invitees := make([]service.AffiliateInvitee, 0)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateInvitee
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.CreatedAt = &createdAt
|
||||
invitees = append(invitees, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invitees, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return fn(ctx, tx.Client())
|
||||
}
|
||||
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin affiliate transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
if err := fn(txCtx, tx.Client()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit affiliate transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
|
||||
summary, err := queryAffiliateByUserID(ctx, client, userID)
|
||||
if err == nil {
|
||||
return summary, nil
|
||||
}
|
||||
if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i < affiliateCodeMaxAttempts; i++ {
|
||||
code, codeErr := generateAffiliateCode()
|
||||
if codeErr != nil {
|
||||
return nil, codeErr
|
||||
}
|
||||
_, insertErr := client.ExecContext(ctx, `
|
||||
INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
|
||||
VALUES ($1, $2, NOW(), NOW())
|
||||
ON CONFLICT (user_id) DO NOTHING`, userID, code)
|
||||
if insertErr == nil {
|
||||
break
|
||||
}
|
||||
if isAffiliateUniqueViolation(insertErr) {
|
||||
continue
|
||||
}
|
||||
return nil, insertErr
|
||||
}
|
||||
|
||||
return queryAffiliateByUserID(ctx, client, userID)
|
||||
}
|
||||
|
||||
func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT user_id,
|
||||
aff_code,
|
||||
aff_code_custom,
|
||||
aff_rebate_rate_percent,
|
||||
inviter_id,
|
||||
aff_count,
|
||||
aff_quota::double precision,
|
||||
aff_frozen_quota::double precision,
|
||||
aff_history_quota::double precision,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM user_affiliates
|
||||
WHERE user_id = $1`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrAffiliateProfileNotFound
|
||||
}
|
||||
|
||||
var out service.AffiliateSummary
|
||||
var inviterID sql.NullInt64
|
||||
var rebateRate sql.NullFloat64
|
||||
if err := rows.Scan(
|
||||
&out.UserID,
|
||||
&out.AffCode,
|
||||
&out.AffCodeCustom,
|
||||
&rebateRate,
|
||||
&inviterID,
|
||||
&out.AffCount,
|
||||
&out.AffQuota,
|
||||
&out.AffFrozenQuota,
|
||||
&out.AffHistoryQuota,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inviterID.Valid {
|
||||
out.InviterID = &inviterID.Int64
|
||||
}
|
||||
if rebateRate.Valid {
|
||||
v := rebateRate.Float64
|
||||
out.AffRebateRatePercent = &v
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT user_id,
|
||||
aff_code,
|
||||
aff_code_custom,
|
||||
aff_rebate_rate_percent,
|
||||
inviter_id,
|
||||
aff_count,
|
||||
aff_quota::double precision,
|
||||
aff_frozen_quota::double precision,
|
||||
aff_history_quota::double precision,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM user_affiliates
|
||||
WHERE aff_code = $1
|
||||
LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrAffiliateProfileNotFound
|
||||
}
|
||||
|
||||
var out service.AffiliateSummary
|
||||
var inviterID sql.NullInt64
|
||||
var rebateRate sql.NullFloat64
|
||||
if err := rows.Scan(
|
||||
&out.UserID,
|
||||
&out.AffCode,
|
||||
&out.AffCodeCustom,
|
||||
&rebateRate,
|
||||
&inviterID,
|
||||
&out.AffCount,
|
||||
&out.AffQuota,
|
||||
&out.AffFrozenQuota,
|
||||
&out.AffHistoryQuota,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inviterID.Valid {
|
||||
out.InviterID = &inviterID.Int64
|
||||
}
|
||||
if rebateRate.Valid {
|
||||
v := rebateRate.Float64
|
||||
out.AffRebateRatePercent = &v
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
|
||||
rows, err := client.QueryContext(ctx,
|
||||
"SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1",
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 0, service.ErrUserNotFound
|
||||
}
|
||||
var balance float64
|
||||
if err := rows.Scan(&balance); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func generateAffiliateCode() (string, error) {
|
||||
buf := make([]byte, affiliateCodeLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", fmt.Errorf("generate affiliate code: %w", err)
|
||||
}
|
||||
for i := range buf {
|
||||
buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)]
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func isAffiliateUniqueViolation(err error) bool {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) {
|
||||
return string(pqErr.Code) == "23505"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateUserAffCode 改写用户的邀请码(自定义专属邀请码)。
|
||||
// 唯一性冲突返回 ErrAffiliateCodeTaken。
|
||||
func (r *affiliateRepository) UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error {
|
||||
if userID <= 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
code := strings.ToUpper(strings.TrimSpace(newCode))
|
||||
if code == "" {
|
||||
return service.ErrAffiliateCodeInvalid
|
||||
}
|
||||
|
||||
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_code = $1,
|
||||
aff_code_custom = true,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, code, userID)
|
||||
if err != nil {
|
||||
if isAffiliateUniqueViolation(err) {
|
||||
return service.ErrAffiliateCodeTaken
|
||||
}
|
||||
return fmt.Errorf("update aff_code: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ResetUserAffCode 把 aff_code 还原为系统随机码,并清除 aff_code_custom 标记。
|
||||
func (r *affiliateRepository) ResetUserAffCode(ctx context.Context, userID int64) (string, error) {
|
||||
if userID <= 0 {
|
||||
return "", service.ErrUserNotFound
|
||||
}
|
||||
var newCode string
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := 0; i < affiliateCodeMaxAttempts; i++ {
|
||||
candidate, codeErr := generateAffiliateCode()
|
||||
if codeErr != nil {
|
||||
return codeErr
|
||||
}
|
||||
res, err := txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_code = $1,
|
||||
aff_code_custom = false,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, candidate, userID)
|
||||
if err != nil {
|
||||
if isAffiliateUniqueViolation(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("reset aff_code: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
newCode = candidate
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reset aff_code: exhausted attempts")
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return newCode, nil
|
||||
}
|
||||
|
||||
// SetUserRebateRate 设置或清除用户专属返利比例。ratePercent==nil 表示清除(沿用全局)。
|
||||
func (r *affiliateRepository) SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
|
||||
if userID <= 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
// nullableArg lets us use a single UPDATE for both "set value" and
|
||||
// "clear" cases — database/sql converts nil interface{} to SQL NULL.
|
||||
res, err := txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_rebate_rate_percent = $1,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, nullableArg(ratePercent), userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set aff_rebate_rate_percent: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// BatchSetUserRebateRate 批量为多个用户设置专属比例(nil 清除)。
|
||||
func (r *affiliateRepository) BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
|
||||
if len(userIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
for _, uid := range userIDs {
|
||||
if uid <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, uid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_rebate_rate_percent = $1,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = ANY($2)`, nullableArg(ratePercent), pq.Array(userIDs))
|
||||
if err != nil {
|
||||
return fmt.Errorf("batch set aff_rebate_rate_percent: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// nullableArg unwraps a *float64 into an interface{} suitable for SQL parameter
|
||||
// binding: nil pointer → SQL NULL, non-nil → the float value.
|
||||
func nullableArg(v *float64) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
|
||||
//
|
||||
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索":
|
||||
// 空 search 时拼接出的 LIKE 模式为 "%%",匹配所有行;非空时按 ILIKE 子串匹配。
|
||||
// 这避免了为两种情况维护两份 SQL 模板。
|
||||
func (r *affiliateRepository) ListUsersWithCustomSettings(ctx context.Context, filter service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
|
||||
page := filter.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := filter.PageSize
|
||||
if pageSize <= 0 || pageSize > 200 {
|
||||
pageSize = 20
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
likePattern := "%" + strings.TrimSpace(filter.Search) + "%"
|
||||
|
||||
const baseFrom = `
|
||||
FROM user_affiliates ua
|
||||
JOIN users u ON u.id = ua.user_id
|
||||
WHERE (ua.aff_code_custom = true OR ua.aff_rebate_rate_percent IS NOT NULL)
|
||||
AND (u.email ILIKE $1 OR u.username ILIKE $1)`
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
total, err := scanInt64(ctx, client, "SELECT COUNT(*)"+baseFrom, likePattern)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("count affiliate admin entries: %w", err)
|
||||
}
|
||||
|
||||
listQuery := `
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.aff_code,
|
||||
ua.aff_code_custom,
|
||||
ua.aff_rebate_rate_percent,
|
||||
ua.aff_count` + baseFrom + `
|
||||
ORDER BY ua.updated_at DESC
|
||||
LIMIT $2 OFFSET $3`
|
||||
|
||||
rows, err := client.QueryContext(ctx, listQuery, likePattern, pageSize, offset)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("list affiliate admin entries: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
entries := make([]service.AffiliateAdminEntry, 0)
|
||||
for rows.Next() {
|
||||
var e service.AffiliateAdminEntry
|
||||
var rebate sql.NullFloat64
|
||||
if err := rows.Scan(&e.UserID, &e.Email, &e.Username, &e.AffCode,
|
||||
&e.AffCodeCustom, &rebate, &e.AffCount); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if rebate.Valid {
|
||||
v := rebate.Float64
|
||||
e.AffRebateRatePercent = &v
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return entries, total, nil
|
||||
}
|
||||
|
||||
// scanInt64 runs a query expected to return a single int64 column (e.g. COUNT).
|
||||
func scanInt64(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
var v int64
|
||||
if err := rows.Scan(&v); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
399
backend/internal/repository/affiliate_repo_integration_test.go
Normal file
399
backend/internal/repository/affiliate_repo_integration_test.go
Normal file
@@ -0,0 +1,399 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 {
|
||||
t.Helper()
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
require.True(t, rows.Next(), "expected one row")
|
||||
var value float64
|
||||
require.NoError(t, rows.Scan(&value))
|
||||
require.NoError(t, rows.Err())
|
||||
return value
|
||||
}
|
||||
|
||||
func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int {
|
||||
t.Helper()
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
require.True(t, rows.Next(), "expected one row")
|
||||
var value int
|
||||
require.NoError(t, rows.Scan(&value))
|
||||
require.NoError(t, rows.Err())
|
||||
return value
|
||||
}
|
||||
|
||||
func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
u := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 5.5,
|
||||
Concurrency: 5,
|
||||
})
|
||||
|
||||
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
|
||||
_, err := client.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
|
||||
require.NoError(t, err)
|
||||
|
||||
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 12.34, transferred, 1e-9)
|
||||
require.InDelta(t, 17.84, balance, 1e-9)
|
||||
|
||||
affQuota := querySingleFloat(t, txCtx, client,
|
||||
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID)
|
||||
require.InDelta(t, 0.0, affQuota, 1e-9)
|
||||
|
||||
persistedBalance := querySingleFloat(t, txCtx, client,
|
||||
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
|
||||
require.InDelta(t, 17.84, persistedBalance, 1e-9)
|
||||
|
||||
ledgerCount := querySingleInt(t, txCtx, client,
|
||||
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
|
||||
require.Equal(t, 1, ledgerCount)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
|
||||
// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx
|
||||
// that already carries a transaction (via dbent.NewTxContext), repo.withTx
|
||||
// must reuse that tx rather than opening a nested one. If this invariant
|
||||
// breaks, AccrueQuota would commit independently and survive a rollback of
|
||||
// the outer tx, which would violate payment_fulfillment's all-or-nothing
|
||||
// semantics.
|
||||
func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
outerTx, err := integrationEntClient.Tx(ctx)
|
||||
require.NoError(t, err, "begin outer tx")
|
||||
// Defensive cleanup: if any require.* below fires before the explicit
|
||||
// Rollback, this prevents the tx from leaking until container teardown.
|
||||
// Rollback is idempotent at the driver level (extra rollback returns an
|
||||
// error we ignore).
|
||||
t.Cleanup(func() { _ = outerTx.Rollback() })
|
||||
client := outerTx.Client()
|
||||
txCtx := dbent.NewTxContext(ctx, outerTx)
|
||||
|
||||
inviter := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
})
|
||||
invitee := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
})
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
_, err = repo.EnsureUserAffiliate(txCtx, inviter.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = repo.EnsureUserAffiliate(txCtx, invitee.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, bound, "invitee must bind to inviter")
|
||||
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
|
||||
require.NoError(t, err)
|
||||
require.True(t, applied, "AccrueQuota must report applied=true")
|
||||
|
||||
// Visible inside the outer tx.
|
||||
innerQuota := querySingleFloat(t, txCtx, client,
|
||||
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID)
|
||||
require.InDelta(t, 3.5, innerQuota, 1e-9)
|
||||
|
||||
// Roll back the outer tx; if AccrueQuota had opened its own inner tx and
|
||||
// committed it, the rows would still be visible to the global client.
|
||||
require.NoError(t, outerTx.Rollback())
|
||||
|
||||
rows, err := integrationEntClient.QueryContext(ctx,
|
||||
"SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)",
|
||||
inviter.ID, invitee.ID)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = rows.Close() }()
|
||||
require.True(t, rows.Next())
|
||||
var postRollbackCount int
|
||||
require.NoError(t, rows.Scan(&postRollbackCount))
|
||||
require.Equal(t, 0, postRollbackCount,
|
||||
"AccrueQuota must propagate the outer tx — found persisted rows after rollback")
|
||||
}
|
||||
|
||||
func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
u := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 3.21,
|
||||
Concurrency: 5,
|
||||
})
|
||||
|
||||
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
|
||||
_, err := client.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
|
||||
VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
|
||||
require.NoError(t, err)
|
||||
|
||||
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
|
||||
require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty)
|
||||
require.InDelta(t, 0.0, transferred, 1e-9)
|
||||
require.InDelta(t, 0.0, balance, 1e-9)
|
||||
|
||||
persistedBalance := querySingleFloat(t, txCtx, client,
|
||||
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
|
||||
require.InDelta(t, 3.21, persistedBalance, 1e-9)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AdminCustomCode covers the success path of admin
|
||||
// invite-code rewrite + reset within a shared test transaction:
|
||||
// - UpdateUserAffCode replaces aff_code, sets aff_code_custom=true, lookup works
|
||||
// - the old code can no longer be found
|
||||
// - ResetUserAffCode reverts aff_code_custom and assigns a new system-format code
|
||||
//
|
||||
// The conflict path (duplicate code → ErrAffiliateCodeTaken) lives in its own
|
||||
// test because a unique-violation aborts the surrounding Postgres tx, which
|
||||
// would poison subsequent assertions in the same transaction.
|
||||
func TestAffiliateRepository_AdminCustomCode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
u := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-custom-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
original, err := repo.EnsureUserAffiliate(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, original.AffCodeCustom, "system-generated codes start as non-custom")
|
||||
originalCode := original.AffCode
|
||||
|
||||
// Rewrite to a custom code
|
||||
customCode := fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)
|
||||
require.NoError(t, repo.UpdateUserAffCode(txCtx, u.ID, customCode))
|
||||
|
||||
updated, err := repo.EnsureUserAffiliate(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, customCode, updated.AffCode)
|
||||
require.True(t, updated.AffCodeCustom)
|
||||
|
||||
// Lookup by new custom code finds the user
|
||||
byCode, err := repo.GetAffiliateByCode(txCtx, customCode)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, u.ID, byCode.UserID)
|
||||
|
||||
// Old system code should no longer match
|
||||
_, err = repo.GetAffiliateByCode(txCtx, originalCode)
|
||||
require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
|
||||
|
||||
// Reset back to a fresh system code, clears custom flag
|
||||
newSysCode, err := repo.ResetUserAffCode(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, customCode, newSysCode)
|
||||
|
||||
reset, err := repo.EnsureUserAffiliate(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newSysCode, reset.AffCode)
|
||||
require.False(t, reset.AffCodeCustom)
|
||||
|
||||
// The old custom code is now free again
|
||||
_, err = repo.GetAffiliateByCode(txCtx, customCode)
|
||||
require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AdminCustomCode_Conflict isolates the unique-violation
|
||||
// path. PostgreSQL aborts the enclosing tx when a unique constraint fires, so
|
||||
// this test must be the only assertion and run in its own tx — production
|
||||
// callers each have their own outer tx, so this matches real behavior.
|
||||
func TestAffiliateRepository_AdminCustomCode_Conflict(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
taker := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-conflict-taker-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
requester := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-conflict-req-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
|
||||
takenCode := fmt.Sprintf("HOT%09d", time.Now().UnixNano()%1_000_000_000)
|
||||
require.NoError(t, repo.UpdateUserAffCode(txCtx, taker.ID, takenCode))
|
||||
|
||||
// Now requester tries to grab the same code → conflict.
|
||||
err := repo.UpdateUserAffCode(txCtx, requester.ID, takenCode)
|
||||
require.ErrorIs(t, err, service.ErrAffiliateCodeTaken)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AdminRebateRate covers per-user exclusive rate
|
||||
// set/clear and the Batch variant including NULL semantics.
|
||||
func TestAffiliateRepository_AdminRebateRate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
u1 := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-rate-%d-a@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
u2 := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-rate-%d-b@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
// Set exclusive rate for u1
|
||||
rate := 42.5
|
||||
require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, &rate))
|
||||
|
||||
got, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.AffRebateRatePercent)
|
||||
require.InDelta(t, 42.5, *got.AffRebateRatePercent, 1e-9)
|
||||
|
||||
// Clear exclusive rate
|
||||
require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, nil))
|
||||
cleared, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, cleared.AffRebateRatePercent)
|
||||
|
||||
// Batch set both users
|
||||
batchRate := 15.0
|
||||
require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, &batchRate))
|
||||
|
||||
for _, uid := range []int64{u1.ID, u2.ID} {
|
||||
v, err := repo.EnsureUserAffiliate(txCtx, uid)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, v.AffRebateRatePercent)
|
||||
require.InDelta(t, 15.0, *v.AffRebateRatePercent, 1e-9)
|
||||
}
|
||||
|
||||
// Batch clear
|
||||
require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, nil))
|
||||
for _, uid := range []int64{u1.ID, u2.ID} {
|
||||
v, err := repo.EnsureUserAffiliate(txCtx, uid)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, v.AffRebateRatePercent)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_ListUsersWithCustomSettings verifies the admin list
|
||||
// only includes users with at least one override applied.
|
||||
func TestAffiliateRepository_ListUsersWithCustomSettings(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
// User without any custom config — should NOT appear in the list.
|
||||
plainEmail := fmt.Sprintf("affiliate-plain-%d@example.com", time.Now().UnixNano())
|
||||
uPlain := mustCreateUser(t, client, &service.User{
|
||||
Email: plainEmail, PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
_, err := repo.EnsureUserAffiliate(txCtx, uPlain.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User with a custom code — should appear.
|
||||
uCode := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-codeonly-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
require.NoError(t, repo.UpdateUserAffCode(txCtx, uCode.ID, fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)))
|
||||
|
||||
// User with only an exclusive rate — should appear.
|
||||
uRate := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-rateonly-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
r := 33.3
|
||||
require.NoError(t, repo.SetUserRebateRate(txCtx, uRate.ID, &r))
|
||||
|
||||
entries, total, err := repo.ListUsersWithCustomSettings(txCtx, service.AffiliateAdminFilter{
|
||||
Page: 1, PageSize: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build a quick lookup to assert per-user attributes (other tests may have
|
||||
// inserted custom rows in the same DB; we only care about our 3).
|
||||
byUserID := make(map[int64]service.AffiliateAdminEntry, len(entries))
|
||||
for _, e := range entries {
|
||||
byUserID[e.UserID] = e
|
||||
}
|
||||
|
||||
require.NotContains(t, byUserID, uPlain.ID, "users without overrides must not appear")
|
||||
|
||||
codeEntry, ok := byUserID[uCode.ID]
|
||||
require.True(t, ok, "custom-code user missing from list")
|
||||
require.True(t, codeEntry.AffCodeCustom)
|
||||
require.Nil(t, codeEntry.AffRebateRatePercent)
|
||||
|
||||
rateEntry, ok := byUserID[uRate.ID]
|
||||
require.True(t, ok, "custom-rate user missing from list")
|
||||
require.False(t, rateEntry.AffCodeCustom)
|
||||
require.NotNil(t, rateEntry.AffRebateRatePercent)
|
||||
require.InDelta(t, 33.3, *rateEntry.AffRebateRatePercent, 1e-9)
|
||||
|
||||
require.GreaterOrEqual(t, total, int64(2), "total must include at least our 2 custom rows")
|
||||
}
|
||||
@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewChannelRepository,
|
||||
NewChannelMonitorRepository,
|
||||
NewChannelMonitorRequestTemplateRepository,
|
||||
NewAffiliateRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
|
||||
@@ -715,6 +715,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"force_email_on_third_party_signup": false,
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25,
|
||||
"affiliate_rebate_rate": 20,
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
"affiliate_rebate_per_invitee_cap": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
@@ -774,6 +778,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"channel_monitor_enabled": true,
|
||||
"channel_monitor_default_interval_seconds": 60,
|
||||
"available_channels_enabled": false,
|
||||
"affiliate_enabled": false,
|
||||
"wechat_connect_enabled": false,
|
||||
"wechat_connect_app_id": "",
|
||||
"wechat_connect_app_secret_configured": false,
|
||||
@@ -895,6 +900,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"custom_endpoints": [],
|
||||
"default_concurrency": 0,
|
||||
"default_balance": 0,
|
||||
"affiliate_rebate_rate": 20,
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
"affiliate_rebate_per_invitee_cap": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
@@ -949,6 +958,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"channel_monitor_enabled": true,
|
||||
"channel_monitor_default_interval_seconds": 60,
|
||||
"available_channels_enabled": false,
|
||||
"affiliate_enabled": false,
|
||||
"wechat_connect_enabled": true,
|
||||
"wechat_connect_app_id": "wx-open-config",
|
||||
"wechat_connect_app_secret_configured": true,
|
||||
|
||||
@@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
|
||||
@@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
@@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
toucher := &recordingActivityToucher{}
|
||||
|
||||
|
||||
@@ -91,6 +91,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// 渠道监控
|
||||
registerChannelMonitorRoutes(admin, h)
|
||||
|
||||
// 邀请返利(专属用户管理)
|
||||
registerAffiliateRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -594,3 +597,18 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
|
||||
}
|
||||
}
|
||||
|
||||
// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
|
||||
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
affiliates := admin.Group("/affiliates")
|
||||
{
|
||||
users := affiliates.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.Affiliate.ListUsers)
|
||||
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
|
||||
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
|
||||
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
|
||||
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ func RegisterUserRoutes(
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
user.GET("/aff", h.User.GetAffiliate)
|
||||
user.POST("/aff/transfer", h.User.TransferAffiliateQuota)
|
||||
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
|
||||
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
|
||||
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
|
||||
|
||||
@@ -393,6 +393,56 @@ func parseTempUnschedInt(value any) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
const (
|
||||
// OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility.
|
||||
OpenAICompactModeAuto = "auto"
|
||||
// OpenAICompactModeForceOn always treats the account as compact-supported.
|
||||
OpenAICompactModeForceOn = "force_on"
|
||||
// OpenAICompactModeForceOff always treats the account as compact-unsupported.
|
||||
OpenAICompactModeForceOff = "force_off"
|
||||
)
|
||||
|
||||
func normalizeOpenAICompactMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case OpenAICompactModeForceOn:
|
||||
return OpenAICompactModeForceOn
|
||||
case OpenAICompactModeForceOff:
|
||||
return OpenAICompactModeForceOff
|
||||
default:
|
||||
return OpenAICompactModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
func stringMappingFromRaw(raw any) map[string]string {
|
||||
switch mapping := raw.(type) {
|
||||
case map[string]any:
|
||||
if len(mapping) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string]string, len(mapping))
|
||||
for key, value := range mapping {
|
||||
if str, ok := value.(string); ok {
|
||||
result[key] = str
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
case map[string]string:
|
||||
if len(mapping) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string]string, len(mapping))
|
||||
for key, value := range mapping {
|
||||
result[key] = value
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
credentialsPtr := mapPtr(a.Credentials)
|
||||
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
|
||||
@@ -598,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
|
||||
return requestedModel, false
|
||||
}
|
||||
|
||||
// GetOpenAICompactMode returns the compact routing mode for an OpenAI account.
|
||||
// Missing or invalid values fall back to "auto".
|
||||
func (a *Account) GetOpenAICompactMode() string {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
return OpenAICompactModeAuto
|
||||
}
|
||||
mode, _ := a.Extra["openai_compact_mode"].(string)
|
||||
return normalizeOpenAICompactMode(mode)
|
||||
}
|
||||
|
||||
// OpenAICompactSupportKnown reports whether compact capability is known for this
|
||||
// account and, when known, whether it is supported.
|
||||
func (a *Account) OpenAICompactSupportKnown() (supported bool, known bool) {
|
||||
if a == nil || !a.IsOpenAI() {
|
||||
return false, false
|
||||
}
|
||||
|
||||
switch a.GetOpenAICompactMode() {
|
||||
case OpenAICompactModeForceOn:
|
||||
return true, true
|
||||
case OpenAICompactModeForceOff:
|
||||
return false, true
|
||||
}
|
||||
|
||||
if a.Extra == nil {
|
||||
return false, false
|
||||
}
|
||||
supported, ok := a.Extra["openai_compact_supported"].(bool)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return supported, true
|
||||
}
|
||||
|
||||
// AllowsOpenAICompact reports whether the account may be considered for compact
|
||||
// requests. Unknown capability remains allowed to avoid breaking older accounts
|
||||
// before an explicit probe has been run.
|
||||
func (a *Account) AllowsOpenAICompact() bool {
|
||||
if a == nil || !a.IsOpenAI() {
|
||||
return false
|
||||
}
|
||||
supported, known := a.OpenAICompactSupportKnown()
|
||||
if !known {
|
||||
return true
|
||||
}
|
||||
return supported
|
||||
}
|
||||
|
||||
// GetCompactModelMapping returns compact-only model remapping configuration.
|
||||
// This mapping is intended for /responses/compact only and does not affect
|
||||
// normal /responses traffic.
|
||||
func (a *Account) GetCompactModelMapping() map[string]string {
|
||||
if a == nil || a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
return stringMappingFromRaw(a.Credentials["compact_model_mapping"])
|
||||
}
|
||||
|
||||
// ResolveCompactMappedModel resolves compact-only model remapping and reports
|
||||
// whether a compact-specific mapping rule matched.
|
||||
func (a *Account) ResolveCompactMappedModel(requestedModel string) (mappedModel string, matched bool) {
|
||||
mapping := a.GetCompactModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel, false
|
||||
}
|
||||
if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
|
||||
return mappedModel, true
|
||||
}
|
||||
return requestedModel, false
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeAPIKey {
|
||||
return ""
|
||||
|
||||
369
backend/internal/service/account_openai_compact_test.go
Normal file
369
backend/internal/service/account_openai_compact_test.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAccountGetOpenAICompactMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil account defaults to auto",
|
||||
want: OpenAICompactModeAuto,
|
||||
},
|
||||
{
|
||||
name: "non openai account defaults to auto",
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
|
||||
},
|
||||
want: OpenAICompactModeAuto,
|
||||
},
|
||||
{
|
||||
name: "missing extra defaults to auto",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
},
|
||||
want: OpenAICompactModeAuto,
|
||||
},
|
||||
{
|
||||
name: "invalid mode falls back to auto",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_mode": " invalid "},
|
||||
},
|
||||
want: OpenAICompactModeAuto,
|
||||
},
|
||||
{
|
||||
name: "force on is normalized",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_mode": " FORCE_ON "},
|
||||
},
|
||||
want: OpenAICompactModeForceOn,
|
||||
},
|
||||
{
|
||||
name: "force off is normalized",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_mode": "force_off"},
|
||||
},
|
||||
want: OpenAICompactModeForceOff,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.account.GetOpenAICompactMode(); got != tt.want {
|
||||
t.Fatalf("GetOpenAICompactMode() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountOpenAICompactSupportKnown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
wantSupported bool
|
||||
wantKnown bool
|
||||
}{
|
||||
{
|
||||
name: "nil account is unknown",
|
||||
wantSupported: false,
|
||||
wantKnown: false,
|
||||
},
|
||||
{
|
||||
name: "non openai account is unknown",
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Extra: map[string]any{"openai_compact_supported": true},
|
||||
},
|
||||
wantSupported: false,
|
||||
wantKnown: false,
|
||||
},
|
||||
{
|
||||
name: "force on overrides probe state",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{
|
||||
"openai_compact_mode": OpenAICompactModeForceOn,
|
||||
"openai_compact_supported": false,
|
||||
},
|
||||
},
|
||||
wantSupported: true,
|
||||
wantKnown: true,
|
||||
},
|
||||
{
|
||||
name: "force off overrides probe state",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{
|
||||
"openai_compact_mode": OpenAICompactModeForceOff,
|
||||
"openai_compact_supported": true,
|
||||
},
|
||||
},
|
||||
wantSupported: false,
|
||||
wantKnown: true,
|
||||
},
|
||||
{
|
||||
name: "auto true is known supported",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_supported": true},
|
||||
},
|
||||
wantSupported: true,
|
||||
wantKnown: true,
|
||||
},
|
||||
{
|
||||
name: "auto false is known unsupported",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_supported": false},
|
||||
},
|
||||
wantSupported: false,
|
||||
wantKnown: true,
|
||||
},
|
||||
{
|
||||
name: "auto without probe state remains unknown",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{},
|
||||
},
|
||||
wantSupported: false,
|
||||
wantKnown: false,
|
||||
},
|
||||
{
|
||||
name: "invalid probe field remains unknown",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_supported": "true"},
|
||||
},
|
||||
wantSupported: false,
|
||||
wantKnown: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotSupported, gotKnown := tt.account.OpenAICompactSupportKnown()
|
||||
if gotSupported != tt.wantSupported || gotKnown != tt.wantKnown {
|
||||
t.Fatalf("OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)", gotSupported, gotKnown, tt.wantSupported, tt.wantKnown)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountAllowsOpenAICompact(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil account does not allow compact",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non openai account does not allow compact",
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unknown openai account remains allowed",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "supported openai account is allowed",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_supported": true},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported openai account is rejected",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_supported": false},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "force on is allowed",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "force off is rejected",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.account.AllowsOpenAICompact(); got != tt.want {
|
||||
t.Fatalf("AllowsOpenAICompact() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetCompactModelMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
want map[string]string
|
||||
}{
|
||||
{
|
||||
name: "nil account returns nil",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "missing credentials returns nil",
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "map any is converted",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4-openai-compact",
|
||||
"invalid": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]string{
|
||||
"gpt-5.4": "gpt-5.4-openai-compact",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map string string is copied",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]string{
|
||||
"gpt-*": "compact-*",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]string{
|
||||
"gpt-*": "compact-*",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.account.GetCompactModelMapping()
|
||||
if !equalStringMap(got, tt.want) {
|
||||
t.Fatalf("GetCompactModelMapping() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolveCompactMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expectedModel string
|
||||
expectedMatch bool
|
||||
}{
|
||||
{
|
||||
name: "no compact mapping reports unmatched",
|
||||
credentials: nil,
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
{
|
||||
name: "exact compact mapping matches",
|
||||
credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4-openai-compact",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4-openai-compact",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "exact passthrough counts as match",
|
||||
credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "longest wildcard wins",
|
||||
credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]any{
|
||||
"gpt-*": "fallback-compact",
|
||||
"gpt-5.4*": "gpt-5.4-openai-compact",
|
||||
"gpt-5.4-mini*": "gpt-5.4-mini-openai-compact",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4-mini",
|
||||
expectedModel: "gpt-5.4-mini-openai-compact",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "missing compact mapping reports unmatched",
|
||||
credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]any{
|
||||
"gpt-5.3": "gpt-5.3-openai-compact",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
gotModel, gotMatch := account.ResolveCompactMappedModel(tt.requestedModel)
|
||||
if gotModel != tt.expectedModel || gotMatch != tt.expectedMatch {
|
||||
t.Fatalf("ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, gotModel, gotMatch, tt.expectedModel, tt.expectedMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func equalStringMap(left, right map[string]string) bool {
|
||||
if len(left) != len(right) {
|
||||
return false
|
||||
}
|
||||
for key, want := range right {
|
||||
if got, ok := left[key]; !ok || got != want {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -165,7 +165,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
// All account types use full Claude Code client characteristics, only auth header differs
|
||||
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
|
||||
// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get account
|
||||
@@ -176,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
|
||||
// Route to platform-specific test method
|
||||
if account.IsOpenAI() {
|
||||
return s.testOpenAIAccountConnection(c, account, modelID, prompt)
|
||||
return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode))
|
||||
}
|
||||
|
||||
if account.IsGemini() {
|
||||
@@ -416,9 +417,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
|
||||
}
|
||||
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
|
||||
ctx := c.Request.Context()
|
||||
_ = prompt
|
||||
mode = normalizeAccountTestMode(mode)
|
||||
|
||||
// Default to openai.DefaultTestModel for OpenAI testing
|
||||
testModelID := modelID
|
||||
@@ -426,14 +428,12 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
testModelID = openai.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
// Align test routing with gateway behavior: OpenAI accounts apply normal
|
||||
// account model mapping, and compact mode applies compact-only mapping on top.
|
||||
testModelID = account.GetMappedModel(testModelID)
|
||||
if mode == AccountTestModeCompact {
|
||||
testModelID = resolveOpenAICompactForwardModel(account, testModelID)
|
||||
return s.testOpenAICompactConnection(c, account, testModelID)
|
||||
}
|
||||
|
||||
// Route to image generation test if an image model is selected
|
||||
@@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
|
||||
}
|
||||
// 401 Unauthorized: 标记账号为永久错误
|
||||
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
|
||||
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
|
||||
@@ -550,6 +553,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
return s.processOpenAIStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testOpenAICompactConnection probes /responses/compact and persists the
|
||||
// resulting capability state on the account.
|
||||
func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
authToken := ""
|
||||
apiURL := ""
|
||||
isOAuth := false
|
||||
chatgptAccountID := ""
|
||||
|
||||
switch {
|
||||
case account.IsOAuth():
|
||||
isOAuth = true
|
||||
authToken = account.GetOpenAIAccessToken()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
apiURL = chatgptCodexAPIURL + "/compact"
|
||||
chatgptAccountID = account.GetChatGPTAccountID()
|
||||
case account.Type == AccountTypeAPIKey:
|
||||
authToken = account.GetOpenAIApiKey()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = appendOpenAIResponsesRequestPathSuffix(buildOpenAIResponsesURL(normalizedBaseURL), "/compact")
|
||||
default:
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
payloadBytes, _ := json.Marshal(createOpenAICompactProbePayload(testModelID))
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
req.Header.Set("Originator", "codex_cli_rs")
|
||||
req.Header.Set("User-Agent", codexCLIUserAgent)
|
||||
req.Header.Set("Version", codexCLIVersion)
|
||||
probeSessionID := compactProbeSessionID(account.ID)
|
||||
req.Header.Set("Session_ID", probeSessionID)
|
||||
req.Header.Set("Conversation_ID", probeSessionID)
|
||||
|
||||
if isOAuth {
|
||||
req.Host = "chatgpt.com"
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
if s.accountRepo != nil {
|
||||
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, err, time.Now())
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
||||
mergeAccountExtra(account, updates)
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
if s.accountRepo != nil {
|
||||
updates := buildOpenAICompactProbeExtraUpdates(resp, body, nil, time.Now())
|
||||
if codexUpdates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(codexUpdates) > 0 {
|
||||
updates = mergeExtraUpdates(updates, codexUpdates)
|
||||
}
|
||||
if len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
||||
mergeAccountExtra(account, updates)
|
||||
}
|
||||
// 探测如返回 429,主动同步限流状态,避免后续短时间内继续选中。
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
|
||||
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
|
||||
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Compact probe succeeded"})
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) {
|
||||
if s == nil || s.accountRepo == nil || account == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var resetAt *time.Time
|
||||
if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
|
||||
resetAt = calculated
|
||||
} else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil {
|
||||
t := time.Unix(*unixTs, 0)
|
||||
resetAt = &t
|
||||
}
|
||||
if resetAt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
account.RateLimitedAt = &now
|
||||
account.RateLimitResetAt = resetAt
|
||||
|
||||
if account.Status == StatusError {
|
||||
if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
|
||||
return
|
||||
}
|
||||
account.Status = StatusActive
|
||||
account.ErrorMessage = ""
|
||||
}
|
||||
}
|
||||
|
||||
// testGeminiAccountConnection tests a Gemini account's connection
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||
ctx := c.Request.Context()
|
||||
@@ -994,13 +1145,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
|
||||
// processOpenAIStream processes the SSE stream from OpenAI Responses API
|
||||
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
seenCompleted := false
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
if seenCompleted {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
@@ -1012,8 +1167,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
|
||||
|
||||
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
if seenCompleted {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
@@ -1029,9 +1187,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
|
||||
if delta, ok := data["delta"].(string); ok && delta != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
|
||||
}
|
||||
case "response.completed":
|
||||
case "response.completed", "response.done":
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
case "response.failed":
|
||||
errorMsg := "OpenAI response failed"
|
||||
if responseData, ok := data["response"].(map[string]any); ok {
|
||||
if errData, ok := responseData["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok && msg != "" {
|
||||
errorMsg = msg
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
@@ -1261,7 +1429,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
|
||||
ginCtx, _ := gin.CreateTestContext(w)
|
||||
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault)
|
||||
|
||||
finishedAt := time.Now()
|
||||
body := w.Body.String()
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
updateCalls := make(chan map[string]any, 1)
|
||||
account := Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
repo := &snapshotUpdateAccountRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
updateExtraCalls: updateCalls,
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-probe"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe","status":"completed"}`)),
|
||||
}}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil))
|
||||
|
||||
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
|
||||
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept"))
|
||||
require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version"))
|
||||
require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
|
||||
require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent"))
|
||||
require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
|
||||
updates := <-updateCalls
|
||||
require.Equal(t, true, updates["openai_compact_supported"])
|
||||
require.Equal(t, http.StatusOK, updates["openai_compact_last_status"])
|
||||
require.Contains(t, rec.Body.String(), `"type":"test_complete"`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
updateCalls := make(chan map[string]any, 1)
|
||||
account := Account{
|
||||
ID: 2,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
repo := &snapshotUpdateAccountRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
updateExtraCalls: updateCalls,
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusNotFound,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`404 page not found`)),
|
||||
}}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil))
|
||||
|
||||
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
|
||||
require.Error(t, err)
|
||||
|
||||
updates := <-updateCalls
|
||||
require.Equal(t, false, updates["openai_compact_supported"])
|
||||
require.Equal(t, http.StatusNotFound, updates["openai_compact_last_status"])
|
||||
require.Contains(t, rec.Body.String(), `"type":"error"`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
updateCalls := make(chan map[string]any, 1)
|
||||
account := Account{
|
||||
ID: 3,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://example.com/v1",
|
||||
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
|
||||
},
|
||||
}
|
||||
repo := &snapshotUpdateAccountRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
updateExtraCalls: updateCalls,
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey","status":"completed"}`)),
|
||||
}}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil))
|
||||
|
||||
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
updates := <-updateCalls
|
||||
require.Equal(t, true, updates["openai_compact_supported"])
|
||||
}
|
||||
|
||||
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
updateCalls := make(chan map[string]any, 1)
|
||||
account := Account{
|
||||
ID: 4,
|
||||
Name: "openai-apikey-default",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
}
|
||||
repo := &snapshotUpdateAccountRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
updateExtraCalls: updateCalls,
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey_default","status":"completed"}`)),
|
||||
}}
|
||||
svc := &AccountTestService{
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil))
|
||||
|
||||
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String())
|
||||
<-updateCalls
|
||||
}
|
||||
@@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
||||
|
||||
type openAIAccountTestRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updatedExtra map[string]any
|
||||
rateLimitedID int64
|
||||
rateLimitedAt *time.Time
|
||||
updatedExtra map[string]any
|
||||
rateLimitedID int64
|
||||
rateLimitedAt *time.Time
|
||||
clearedErrorID int64
|
||||
setErrorID int64
|
||||
setErrorMsg string
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
@@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
|
||||
r.clearedErrorID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorID = id
|
||||
r.setErrorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newTestContext()
|
||||
@@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
@@ -111,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
|
||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
|
||||
func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"}
|
||||
|
||||
`))
|
||||
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 90,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, recorder.Body.String(), "response.completed")
|
||||
require.NotContains(t, recorder.Body.String(), `"success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`)
|
||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
@@ -130,15 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
|
||||
ID: 88,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusError,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, account.ID, repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.Equal(t, account.ID, repo.clearedErrorID)
|
||||
require.Equal(t, StatusActive, account.Status)
|
||||
require.Empty(t, account.ErrorMessage)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`)
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 77,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusError,
|
||||
ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions",
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, account.ID, repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.Equal(t, account.ID, repo.clearedErrorID)
|
||||
require.Equal(t, StatusActive, account.Status)
|
||||
require.Empty(t, account.ErrorMessage)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
require.Empty(t, repo.updatedExtra)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`)
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 78,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, account.ID, repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.Zero(t, repo.clearedErrorID)
|
||||
require.Equal(t, StatusActive, account.Status)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 79,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusError,
|
||||
ErrorMessage: "stale 403",
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Zero(t, repo.rateLimitedID)
|
||||
require.Nil(t, repo.rateLimitedAt)
|
||||
require.Zero(t, repo.clearedErrorID)
|
||||
require.Equal(t, StatusError, account.Status)
|
||||
require.Equal(t, "stale 403", account.ErrorMessage)
|
||||
require.Nil(t, account.RateLimitResetAt)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`)
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 80,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, account.ID, repo.setErrorID)
|
||||
require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
|
||||
require.Zero(t, repo.rateLimitedID)
|
||||
require.Zero(t, repo.clearedErrorID)
|
||||
require.Nil(t, account.RateLimitResetAt)
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ const (
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
openAICodexProbeVersion = "0.104.0"
|
||||
openAICodexProbeVersion = "0.125.0"
|
||||
)
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
|
||||
490
backend/internal/service/affiliate_service.go
Normal file
490
backend/internal/service/affiliate_service.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
|
||||
ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
|
||||
ErrAffiliateCodeTaken = infraerrors.Conflict("AFFILIATE_CODE_TAKEN", "affiliate code already in use")
|
||||
ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
|
||||
ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
|
||||
)
|
||||
|
||||
const (
|
||||
affiliateInviteesLimit = 100
|
||||
// AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated
|
||||
// 12-char codes and admin-customized codes (e.g. "VIP2026").
|
||||
AffiliateCodeMinLength = 4
|
||||
AffiliateCodeMaxLength = 32
|
||||
)
|
||||
|
||||
// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash.
|
||||
// All input passes through strings.ToUpper before validation, so lowercase from
|
||||
// users is normalized — admins may supply mixed case in their UI.
|
||||
var affiliateCodeValidChar = func() [256]bool {
|
||||
var tbl [256]bool
|
||||
for c := byte('A'); c <= 'Z'; c++ {
|
||||
tbl[c] = true
|
||||
}
|
||||
for c := byte('0'); c <= '9'; c++ {
|
||||
tbl[c] = true
|
||||
}
|
||||
tbl['_'] = true
|
||||
tbl['-'] = true
|
||||
return tbl
|
||||
}()
|
||||
|
||||
// isValidAffiliateCodeFormat validates code format for both binding (user input)
|
||||
// and admin updates. Caller is expected to upper-case the input first.
|
||||
func isValidAffiliateCodeFormat(code string) bool {
|
||||
if len(code) < AffiliateCodeMinLength || len(code) > AffiliateCodeMaxLength {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(code); i++ {
|
||||
if !affiliateCodeValidChar[code[i]] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type AffiliateSummary struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
AffCode string `json:"aff_code"`
|
||||
AffCodeCustom bool `json:"aff_code_custom"`
|
||||
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
|
||||
InviterID *int64 `json:"inviter_id,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffFrozenQuota float64 `json:"aff_frozen_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type AffiliateInvitee struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt *time.Time `json:"created_at,omitempty"`
|
||||
TotalRebate float64 `json:"total_rebate"`
|
||||
}
|
||||
|
||||
type AffiliateDetail struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
AffCode string `json:"aff_code"`
|
||||
InviterID *int64 `json:"inviter_id,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffFrozenQuota float64 `json:"aff_frozen_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
|
||||
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
|
||||
// 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
|
||||
EffectiveRebateRatePercent float64 `json:"effective_rebate_rate_percent"`
|
||||
Invitees []AffiliateInvitee `json:"invitees"`
|
||||
}
|
||||
|
||||
type AffiliateRepository interface {
|
||||
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
|
||||
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
|
||||
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
|
||||
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
|
||||
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
|
||||
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
|
||||
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
|
||||
|
||||
// 管理端:用户级专属配置
|
||||
UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error
|
||||
ResetUserAffCode(ctx context.Context, userID int64) (string, error)
|
||||
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
|
||||
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
|
||||
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
|
||||
}
|
||||
|
||||
// AffiliateAdminFilter 列表筛选条件
|
||||
type AffiliateAdminFilter struct {
|
||||
Search string
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// AffiliateAdminEntry 专属用户列表条目
|
||||
type AffiliateAdminEntry struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
AffCode string `json:"aff_code"`
|
||||
AffCodeCustom bool `json:"aff_code_custom"`
|
||||
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
}
|
||||
|
||||
type AffiliateService struct {
|
||||
repo AffiliateRepository
|
||||
settingService *SettingService
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
func NewAffiliateService(repo AffiliateRepository, settingService *SettingService, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
|
||||
return &AffiliateService{
|
||||
repo: repo,
|
||||
settingService: settingService,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled reports whether the affiliate (邀请返利) feature is turned on.
|
||||
func (s *AffiliateService) IsEnabled(ctx context.Context) bool {
|
||||
if s == nil || s.settingService == nil {
|
||||
return AffiliateEnabledDefault
|
||||
}
|
||||
return s.settingService.IsAffiliateEnabled(ctx)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
|
||||
if userID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
|
||||
}
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.EnsureUserAffiliate(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
|
||||
// Lazy thaw: move any matured frozen quota to available before reading.
|
||||
if s != nil && s.repo != nil {
|
||||
// best-effort: thaw failure is non-fatal
|
||||
_, _ = s.repo.ThawFrozenQuota(ctx, userID)
|
||||
}
|
||||
|
||||
summary, err := s.EnsureUserAffiliate(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
invitees, err := s.listInvitees(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AffiliateDetail{
|
||||
UserID: summary.UserID,
|
||||
AffCode: summary.AffCode,
|
||||
InviterID: summary.InviterID,
|
||||
AffCount: summary.AffCount,
|
||||
AffQuota: summary.AffQuota,
|
||||
AffFrozenQuota: summary.AffFrozenQuota,
|
||||
AffHistoryQuota: summary.AffHistoryQuota,
|
||||
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
|
||||
Invitees: invitees,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
|
||||
code := strings.ToUpper(strings.TrimSpace(rawCode))
|
||||
if code == "" {
|
||||
return nil
|
||||
}
|
||||
if s == nil || s.repo == nil {
|
||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
// 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程)
|
||||
if !s.IsEnabled(ctx) {
|
||||
return nil
|
||||
}
|
||||
if !isValidAffiliateCodeFormat(code) {
|
||||
return ErrAffiliateCodeInvalid
|
||||
}
|
||||
|
||||
selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if selfSummary.InviterID != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAffiliateProfileNotFound) {
|
||||
return ErrAffiliateCodeInvalid
|
||||
}
|
||||
return err
|
||||
}
|
||||
if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID {
|
||||
return ErrAffiliateCodeInvalid
|
||||
}
|
||||
|
||||
bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !bound {
|
||||
return ErrAffiliateAlreadyBound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
|
||||
return 0, nil
|
||||
}
|
||||
// 总开关关闭时,新充值不再产生返利
|
||||
if !s.IsEnabled(ctx) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 加载邀请人 profile,优先使用专属比例(覆盖全局)
|
||||
inviterSummary, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// 有效期检查:超过返利有效期后不再产生返利
|
||||
if s.settingService != nil {
|
||||
if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
|
||||
if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
|
||||
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
|
||||
if rebate <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 单人上限检查:精确截断到剩余额度
|
||||
if s.settingService != nil {
|
||||
if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
|
||||
existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if existing >= perInviteeCap {
|
||||
return 0, nil
|
||||
}
|
||||
if remaining := perInviteeCap - existing; rebate > remaining {
|
||||
rebate = roundTo(remaining, 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var freezeHours int
|
||||
if s.settingService != nil {
|
||||
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
|
||||
}
|
||||
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !applied {
|
||||
return 0, nil
|
||||
}
|
||||
return rebate, nil
|
||||
}
|
||||
|
||||
// resolveRebateRatePercent returns the inviter's exclusive rate when set,
|
||||
// otherwise the global setting value (clamped to [Min, Max]).
|
||||
func (s *AffiliateService) resolveRebateRatePercent(ctx context.Context, inviter *AffiliateSummary) float64 {
|
||||
if inviter != nil && inviter.AffRebateRatePercent != nil {
|
||||
v := *inviter.AffRebateRatePercent
|
||||
if math.IsNaN(v) || math.IsInf(v, 0) {
|
||||
return s.globalRebateRatePercent(ctx)
|
||||
}
|
||||
return clampAffiliateRebateRate(v)
|
||||
}
|
||||
return s.globalRebateRatePercent(ctx)
|
||||
}
|
||||
|
||||
// globalRebateRatePercent reads the system-wide rebate rate via SettingService,
|
||||
// returning the documented default when SettingService is unavailable.
|
||||
func (s *AffiliateService) globalRebateRatePercent(ctx context.Context) float64 {
|
||||
if s == nil || s.settingService == nil {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
return s.settingService.GetAffiliateRebateRatePercent(ctx)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
|
||||
transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if transferred > 0 {
|
||||
s.invalidateAffiliateCaches(ctx, userID)
|
||||
}
|
||||
return transferred, balance, nil
|
||||
}
|
||||
|
||||
func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range invitees {
|
||||
invitees[i].Email = maskEmail(invitees[i].Email)
|
||||
}
|
||||
return invitees, nil
|
||||
}
|
||||
|
||||
func roundTo(v float64, scale int) float64 {
|
||||
factor := math.Pow10(scale)
|
||||
return math.Round(v*factor) / factor
|
||||
}
|
||||
|
||||
func maskEmail(email string) string {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
at := strings.Index(email, "@")
|
||||
if at <= 0 || at >= len(email)-1 {
|
||||
return "***"
|
||||
}
|
||||
|
||||
local := email[:at]
|
||||
domain := email[at+1:]
|
||||
dot := strings.LastIndex(domain, ".")
|
||||
|
||||
maskedLocal := maskSegment(local)
|
||||
if dot <= 0 || dot >= len(domain)-1 {
|
||||
return maskedLocal + "@" + maskSegment(domain)
|
||||
}
|
||||
|
||||
domainName := domain[:dot]
|
||||
tld := domain[dot:]
|
||||
return maskedLocal + "@" + maskSegment(domainName) + tld
|
||||
}
|
||||
|
||||
func maskSegment(s string) string {
|
||||
r := []rune(s)
|
||||
if len(r) == 0 {
|
||||
return "***"
|
||||
}
|
||||
if len(r) == 1 {
|
||||
return string(r[0]) + "***"
|
||||
}
|
||||
return string(r[0]) + "***"
|
||||
}
|
||||
|
||||
func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) {
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
if s.billingCacheService != nil {
|
||||
if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil {
|
||||
logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Admin: 专属配置管理
|
||||
// =========================
|
||||
|
||||
// validateExclusiveRate ensures a per-user override is finite and within
|
||||
// [Min, Max]. nil is always valid (means "clear / fall back to global").
|
||||
func validateExclusiveRate(ratePercent *float64) error {
|
||||
if ratePercent == nil {
|
||||
return nil
|
||||
}
|
||||
v := *ratePercent
|
||||
if math.IsNaN(v) || math.IsInf(v, 0) {
|
||||
return infraerrors.BadRequest("INVALID_RATE", "invalid rebate rate")
|
||||
}
|
||||
if v < AffiliateRebateRateMin || v > AffiliateRebateRateMax {
|
||||
return infraerrors.BadRequest("INVALID_RATE", "rebate rate out of range")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。
|
||||
func (s *AffiliateService) AdminUpdateUserAffCode(ctx context.Context, userID int64, rawCode string) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
code := strings.ToUpper(strings.TrimSpace(rawCode))
|
||||
if !isValidAffiliateCodeFormat(code) {
|
||||
return ErrAffiliateCodeInvalid
|
||||
}
|
||||
return s.repo.UpdateUserAffCode(ctx, userID, code)
|
||||
}
|
||||
|
||||
// AdminResetUserAffCode 重置用户邀请码为系统随机码。
|
||||
func (s *AffiliateService) AdminResetUserAffCode(ctx context.Context, userID int64) (string, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return "", infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ResetUserAffCode(ctx, userID)
|
||||
}
|
||||
|
||||
// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。
|
||||
func (s *AffiliateService) AdminSetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
if err := validateExclusiveRate(ratePercent); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repo.SetUserRebateRate(ctx, userID, ratePercent)
|
||||
}
|
||||
|
||||
// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。
|
||||
func (s *AffiliateService) AdminBatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
if err := validateExclusiveRate(ratePercent); err != nil {
|
||||
return err
|
||||
}
|
||||
cleaned := make([]int64, 0, len(userIDs))
|
||||
for _, uid := range userIDs {
|
||||
if uid > 0 {
|
||||
cleaned = append(cleaned, uid)
|
||||
}
|
||||
}
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.repo.BatchSetUserRebateRate(ctx, cleaned, ratePercent)
|
||||
}
|
||||
|
||||
// AdminListCustomUsers 列出有专属配置的用户。
|
||||
func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListUsersWithCustomSettings(ctx, filter)
|
||||
}
|
||||
131
backend/internal/service/affiliate_service_test.go
Normal file
131
backend/internal/service/affiliate_service_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
|
||||
// AffRebateRatePercent overrides the global rate, that NULL falls back to the
|
||||
// global rate, and that out-of-range exclusive rates are clamped silently.
|
||||
//
|
||||
// SettingService is left nil here so globalRebateRatePercent returns the
|
||||
// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
|
||||
// fallback path without spinning up a settings stub.
|
||||
func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) {
|
||||
t.Parallel()
|
||||
svc := &AffiliateService{}
|
||||
|
||||
// nil exclusive rate → falls back to global default (20%)
|
||||
require.InDelta(t, AffiliateRebateRateDefault,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9)
|
||||
|
||||
// exclusive rate set → overrides global
|
||||
rate := 50.0
|
||||
require.InDelta(t, 50.0,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9)
|
||||
|
||||
// exclusive rate 0 → returns 0 (no rebate, intentional)
|
||||
zero := 0.0
|
||||
require.InDelta(t, 0.0,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9)
|
||||
|
||||
// exclusive rate above max → clamped to Max
|
||||
tooHigh := 250.0
|
||||
require.InDelta(t, AffiliateRebateRateMax,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9)
|
||||
|
||||
// exclusive rate below min → clamped to Min
|
||||
tooLow := -5.0
|
||||
require.InDelta(t, AffiliateRebateRateMin,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9)
|
||||
}
|
||||
|
||||
// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
|
||||
// safely handles a nil settingService dependency by returning the default
|
||||
// (off). This protects callers from nil-pointer crashes in misconfigured
|
||||
// environments.
|
||||
func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
svc := &AffiliateService{}
|
||||
require.False(t, svc.IsEnabled(context.Background()))
|
||||
require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background()))
|
||||
}
|
||||
|
||||
// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
|
||||
// admin-facing rate setters: nil is always valid (clear), in-range values
|
||||
// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
|
||||
func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, validateExclusiveRate(nil))
|
||||
|
||||
for _, v := range []float64{0, 0.01, 50, 99.99, 100} {
|
||||
v := v
|
||||
require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v)
|
||||
}
|
||||
|
||||
for _, v := range []float64{-0.01, 100.01, -100, 200} {
|
||||
v := v
|
||||
require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v)
|
||||
}
|
||||
|
||||
nan := math.NaN()
|
||||
require.Error(t, validateExclusiveRate(&nan))
|
||||
posInf := math.Inf(1)
|
||||
require.Error(t, validateExclusiveRate(&posInf))
|
||||
negInf := math.Inf(-1)
|
||||
require.Error(t, validateExclusiveRate(&negInf))
|
||||
}
|
||||
|
||||
func TestMaskEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com"))
|
||||
require.Equal(t, "x***@d***", maskEmail("x@domain"))
|
||||
require.Equal(t, "", maskEmail(""))
|
||||
}
|
||||
|
||||
func TestIsValidAffiliateCodeFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 邀请码格式校验同时服务于:
|
||||
// 1) 系统自动生成的 12 位随机码(A-Z 去 I/O,2-9 去 0/1)
|
||||
// 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1")
|
||||
// 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper)。
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"valid canonical 12-char", "ABCDEFGHJKLM", true},
|
||||
{"valid all digits 2-9", "234567892345", true},
|
||||
{"valid mixed", "A2B3C4D5E6F7", true},
|
||||
{"valid admin custom short", "VIP1", true},
|
||||
{"valid admin custom with hyphen", "NEW-USER", true},
|
||||
{"valid admin custom with underscore", "VIP_2026", true},
|
||||
{"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true},
|
||||
// Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
|
||||
{"letter I now allowed", "IBCDEFGHJKLM", true},
|
||||
{"letter O now allowed", "OBCDEFGHJKLM", true},
|
||||
{"digit 0 now allowed", "0BCDEFGHJKLM", true},
|
||||
{"digit 1 now allowed", "1BCDEFGHJKLM", true},
|
||||
{"too short (3 chars)", "ABC", false},
|
||||
{"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false},
|
||||
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
|
||||
{"empty", "", false},
|
||||
{"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset
|
||||
{"ascii punctuation .", "ABCDEFGHJK.M", false},
|
||||
{"whitespace", "ABCDEFGHJK M", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
user *User,
|
||||
invitationCode string,
|
||||
signupSource string,
|
||||
affiliateCode string,
|
||||
) error {
|
||||
if s == nil || user == nil || user.ID <= 0 {
|
||||
return ErrServiceUnavailable
|
||||
@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -72,6 +72,7 @@ type AuthService struct {
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
affiliateService *AffiliateService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
}
|
||||
|
||||
@@ -98,6 +99,7 @@ func NewAuthService(
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
affiliateService *AffiliateService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
entClient: entClient,
|
||||
@@ -110,6 +112,7 @@ func NewAuthService(
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
affiliateService: affiliateService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
}
|
||||
}
|
||||
@@ -123,11 +126,11 @@ func (s *AuthService) EntClient() *dbent.Client {
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "", "")
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
|
||||
// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -223,6 +226,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
}
|
||||
s.postAuthUserBootstrap(ctx, user, "email", true)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
if s.affiliateService != nil {
|
||||
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
|
||||
}
|
||||
if code := strings.TrimSpace(affiliateCode); code != "" {
|
||||
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
|
||||
// 邀请返利码绑定失败不影响注册,只记录日志
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 标记邀请码为已使用(如果使用了邀请码)
|
||||
if invitationRedeemCode != nil {
|
||||
@@ -549,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
|
||||
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
@@ -652,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
}
|
||||
} else {
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
@@ -669,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
@@ -763,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
|
||||
}
|
||||
}
|
||||
|
||||
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
|
||||
// for an OAuth-registered user. Failures are logged but never block registration.
|
||||
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
|
||||
if s.affiliateService == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
|
||||
}
|
||||
if code := strings.TrimSpace(affiliateCode); code != "" {
|
||||
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return
|
||||
|
||||
@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||
emailSvc = service.NewEmailService(settingRepo, emailCache)
|
||||
}
|
||||
|
||||
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
|
||||
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
|
||||
return svc, repo, client
|
||||
}
|
||||
|
||||
@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, cache)
|
||||
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
|
||||
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
|
||||
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
|
||||
ID: 41,
|
||||
|
||||
@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||
values: settings,
|
||||
}, cfg)
|
||||
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
|
||||
return svc, repo, client
|
||||
}
|
||||
|
||||
|
||||
@@ -212,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
nil,
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
nil, // affiliateService
|
||||
)
|
||||
}
|
||||
|
||||
@@ -243,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
@@ -255,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
|
||||
@@ -269,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
require.ErrorContains(t, err, "verify code")
|
||||
}
|
||||
@@ -621,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.NotNil(t, user)
|
||||
@@ -657,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.Equal(t, existing.ID, user.ID)
|
||||
|
||||
@@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
|
||||
nil, // emailQueueService
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
nil, // affiliateService
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,19 @@ const (
|
||||
RoleUser = domain.RoleUser
|
||||
)
|
||||
|
||||
// Affiliate rebate settings
|
||||
const (
|
||||
AffiliateRebateRateDefault = 20.0
|
||||
AffiliateRebateRateMin = 0.0
|
||||
AffiliateRebateRateMax = 100.0
|
||||
AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
|
||||
AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容)
|
||||
AffiliateRebateFreezeHoursMax = 720 // 最大 30 天
|
||||
AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
|
||||
AffiliateRebateDurationDaysMax = 3650 // ~10 年
|
||||
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
const (
|
||||
PlatformAnthropic = domain.PlatformAnthropic
|
||||
@@ -87,6 +100,11 @@ const (
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
|
||||
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
|
||||
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
|
||||
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
|
||||
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
|
||||
@@ -762,8 +762,14 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(
|
||||
system := gjson.GetBytes(upstream.lastBody, "system")
|
||||
require.True(t, system.Exists())
|
||||
require.True(t, system.IsArray(), "system should be an array")
|
||||
require.Equal(t, claudeCodeSystemPrompt, system.Array()[0].Get("text").String())
|
||||
require.Equal(t, "ephemeral", system.Array()[0].Get("cache_control.type").String())
|
||||
arr := system.Array()
|
||||
require.Len(t, arr, 2, "system array should have billing block + cc prompt block")
|
||||
|
||||
require.Contains(t, arr[0].Get("text").String(), "x-anthropic-billing-header:")
|
||||
require.Contains(t, arr[0].Get("text").String(), "cc_version=")
|
||||
|
||||
require.Equal(t, claudeCodeSystemPrompt, arr[1].Get("text").String())
|
||||
require.Equal(t, "ephemeral", arr[1].Get("cache_control.type").String())
|
||||
|
||||
// 原始 system prompt 应迁移至 messages 中
|
||||
messages := gjson.GetBytes(upstream.lastBody, "messages")
|
||||
|
||||
98
backend/internal/service/gateway_billing_block.go
Normal file
98
backend/internal/service/gateway_billing_block.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// fingerprintSalt 是计算 cc_version 后缀指纹的盐值。
|
||||
//
|
||||
// 来源:与 Parrot src/transform/cc_mimicry.py 的 FINGERPRINT_SALT 完全一致;
|
||||
// 这是真实 Claude Code CLI 抓包推导出的常量,改动会导致 fp 与 CLI 不一致,
|
||||
// 进一步触发 Anthropic 的第三方检测。
|
||||
const fingerprintSalt = "59cf53e54c78"
|
||||
|
||||
// computeClaudeCodeFingerprint 复刻真实 Claude Code CLI 的 cc_version 指纹算法:
|
||||
//
|
||||
// 1. 取 messages 中第一条 role=user 的纯文本(首块 text)
|
||||
// 2. 取该文本的第 4、7、20 字符(不足以 '0' 补齐)
|
||||
// 3. SHA256(SALT + chars + cc_version) 取 hex 前 3 字符
|
||||
//
|
||||
// 算法来自 Parrot src/transform/cc_mimicry.py:compute_fingerprint,与官方 CLI 字节对齐。
|
||||
// 任何偏差都会导致 cc_version=X.Y.Z.{fp} 在上游侧与真实 CLI 不一致。
|
||||
func computeClaudeCodeFingerprint(body []byte, version string) string {
|
||||
firstText := extractFirstUserText(body)
|
||||
indices := []int{4, 7, 20}
|
||||
chars := make([]byte, 0, 3)
|
||||
for _, i := range indices {
|
||||
if i < len(firstText) {
|
||||
chars = append(chars, firstText[i])
|
||||
} else {
|
||||
chars = append(chars, '0')
|
||||
}
|
||||
}
|
||||
sum := sha256.Sum256([]byte(fingerprintSalt + string(chars) + version))
|
||||
return hex.EncodeToString(sum[:])[:3]
|
||||
}
|
||||
|
||||
// extractFirstUserText 提取 messages 中第一条 user 消息的首段 text 内容。
|
||||
// 兼容 string 和 []block 两种 content 格式。
|
||||
func extractFirstUserText(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
first := ""
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
if msg.Get("role").String() != "user" {
|
||||
return true
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
first = content.String()
|
||||
return false
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, block gjson.Result) bool {
|
||||
if block.Get("type").String() == "text" {
|
||||
first = block.Get("text").String()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return false
|
||||
}
|
||||
return false
|
||||
})
|
||||
return first
|
||||
}
|
||||
|
||||
// buildBillingAttributionBlockJSON 构造 system 数组的 billing attribution block。
|
||||
//
|
||||
// 形态严格对齐真实 Claude Code CLI:
|
||||
//
|
||||
// {"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.92.{fp}; cc_entrypoint=cli; cch=00000;"}
|
||||
//
|
||||
// cch=00000 是签名占位符,由 signBillingHeaderCCH 在 buildUpstreamRequest 阶段
|
||||
// 替换为基于完整 body 的 xxhash64 5 位十六进制摘要。
|
||||
//
|
||||
// 此 block 不带 cache_control(与真实 CLI 一致;cache breakpoint 由后续的
|
||||
// Claude Code prompt block 承担)。
|
||||
func buildBillingAttributionBlockJSON(body []byte, cliVersion string) ([]byte, error) {
|
||||
if cliVersion == "" {
|
||||
return nil, fmt.Errorf("cliVersion required")
|
||||
}
|
||||
fp := computeClaudeCodeFingerprint(body, cliVersion)
|
||||
text := fmt.Sprintf(
|
||||
"x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=cli; cch=00000;",
|
||||
cliVersion, fp,
|
||||
)
|
||||
return json.Marshal(map[string]string{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
})
|
||||
}
|
||||
@@ -41,12 +41,13 @@ func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.
|
||||
resultStr := string(result)
|
||||
|
||||
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
|
||||
require.NotContains(t, resultStr, `"temperature"`)
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"temperature"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`, `"max_tokens"`)
|
||||
require.Contains(t, resultStr, `"temperature":0.2`)
|
||||
require.NotContains(t, resultStr, `"tool_choice"`)
|
||||
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
|
||||
require.Contains(t, resultStr, `"tools":[]`)
|
||||
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
|
||||
require.Contains(t, resultStr, `"max_tokens":128000`)
|
||||
}
|
||||
|
||||
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
|
||||
|
||||
@@ -85,15 +85,16 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
return nil, fmt.Errorf("marshal anthropic request: %w", err)
|
||||
}
|
||||
|
||||
// 6. Apply Claude Code mimicry for OAuth accounts
|
||||
isClaudeCode := false // CC API is never Claude Code
|
||||
// 6. Apply Claude Code mimicry for OAuth accounts.
|
||||
// Chat Completions 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
|
||||
// 必须完整执行 /v1/messages 主路径上的伪装链路(system 重写 + normalize + metadata 注入),
|
||||
// 否则会被 Anthropic 判为第三方应用并扣 extra usage。
|
||||
// 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
|
||||
isClaudeCode := false
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
|
||||
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
|
||||
}
|
||||
anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
|
||||
}
|
||||
|
||||
// 7. Enforce cache_control block limit
|
||||
@@ -312,7 +313,14 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.JSON(http.StatusOK, ccResp)
|
||||
// Marshal then bytes-replace so tool name mapping is reversed at byte level
|
||||
// (parity with Parrot non-stream flow that marshals → restore → emit).
|
||||
if respBytes, err := json.Marshal(ccResp); err == nil {
|
||||
respBytes = reverseToolNamesIfPresent(c, respBytes)
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
|
||||
} else {
|
||||
c.JSON(http.StatusOK, ccResp)
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
@@ -383,7 +391,10 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
// Reverse tool name mapping: fake → real, per-chunk bytes.Replace.
|
||||
// c 可能持有请求侧注入的 ToolNameRewrite;无则仅做静态前缀还原。
|
||||
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
|
||||
if _, err := fmt.Fprint(c.Writer, out); err != nil {
|
||||
return true // client disconnected
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -82,15 +82,16 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
return nil, fmt.Errorf("marshal anthropic request: %w", err)
|
||||
}
|
||||
|
||||
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
|
||||
isClaudeCode := false // Responses API is never Claude Code
|
||||
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints).
|
||||
// OpenAI Responses 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
|
||||
// 必须完整执行 /v1/messages 主路径上的伪装链路(system 重写 + normalize + metadata 注入),
|
||||
// 否则会被 Anthropic 判为第三方应用并扣 extra usage。
|
||||
// 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
|
||||
isClaudeCode := false
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
|
||||
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
|
||||
}
|
||||
anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
|
||||
}
|
||||
|
||||
// 7. Enforce cache_control block limit
|
||||
@@ -331,7 +332,12 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse(
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.JSON(http.StatusOK, responsesResp)
|
||||
if respBytes, err := json.Marshal(responsesResp); err == nil {
|
||||
respBytes = reverseToolNamesIfPresent(c, respBytes)
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
|
||||
} else {
|
||||
c.JSON(http.StatusOK, responsesResp)
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
@@ -419,7 +425,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
|
||||
if _, err := fmt.Fprint(c.Writer, out); err != nil {
|
||||
logger.L().Info("forward_as_responses stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
@@ -439,7 +446,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
|
||||
fmt.Fprint(c.Writer, out) //nolint:errcheck
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
141
backend/internal/service/gateway_messages_cache.go
Normal file
141
backend/internal/service/gateway_messages_cache.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。
|
||||
// 与 Parrot _strip_message_cache_control 语义一致。
|
||||
//
|
||||
// 为什么必须整体清空:客户端(特别是 Claude Code)经常把 cache_control 打在
|
||||
// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条
|
||||
// 变成中间某条,cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。
|
||||
// 统一由代理重新打断点(addMessageCacheBreakpoints)才能在多轮间稳定。
|
||||
func stripMessageCacheControl(body []byte) []byte {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
msgIdx := -1
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
msgIdx++
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
return true
|
||||
}
|
||||
blockIdx := -1
|
||||
content.ForEach(func(_, block gjson.Result) bool {
|
||||
blockIdx++
|
||||
if !block.Get("cache_control").Exists() {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, blockIdx)
|
||||
if next, err := sjson.DeleteBytes(body, path); err == nil {
|
||||
body = next
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
return body
|
||||
}
|
||||
|
||||
// addMessageCacheBreakpoints 在 messages 上注入两个稳定的 cache 断点:
|
||||
// 1. 最后一条 message
|
||||
// 2. 当 messages 数量 ≥ 4 时,倒数第二个 role=user 的 message
|
||||
//
|
||||
// 与 Parrot add_cache_breakpoints 一致。两个断点 + system prompt block 的断点
|
||||
// + tools[-1] 的断点共同构成最多 4 个断点(Anthropic 上限)。
|
||||
//
|
||||
// cache_control ttl 策略:
|
||||
// - 若目标 block 已有 cache_control.ttl → 不覆盖
|
||||
// - 否则写入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
|
||||
//
|
||||
// 调用前应先 stripMessageCacheControl 以保证幂等和稳定。
|
||||
func addMessageCacheBreakpoints(body []byte) []byte {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
arr := messages.Array()
|
||||
if len(arr) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
body = injectCacheControlOnLastContentBlock(body, len(arr)-1, &arr[len(arr)-1])
|
||||
|
||||
if len(arr) >= 4 {
|
||||
userCount := 0
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
if arr[i].Get("role").String() != "user" {
|
||||
continue
|
||||
}
|
||||
userCount++
|
||||
if userCount == 2 {
|
||||
body = injectCacheControlOnLastContentBlock(body, i, &arr[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx]
|
||||
// 的最后一个 content block 上。若 content 是 string,先升级成单块 text 数组
|
||||
// (对齐 Parrot _inject_cache_on_msg 的行为)。
|
||||
//
|
||||
// msg 是调用方已持有的 gjson.Result 快照,用于省一次 GetBytes。
|
||||
func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte {
|
||||
content := msg.Get("content")
|
||||
|
||||
if content.Type == gjson.String {
|
||||
text := content.String()
|
||||
blockRaw := fmt.Sprintf(
|
||||
`[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`,
|
||||
mustJSONString(text), claude.DefaultCacheControlTTL,
|
||||
)
|
||||
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("messages.%d.content", idx), []byte(blockRaw)); err == nil {
|
||||
body = next
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
if !content.IsArray() {
|
||||
return body
|
||||
}
|
||||
contentArr := content.Array()
|
||||
if len(contentArr) == 0 {
|
||||
return body
|
||||
}
|
||||
lastBlockIdx := len(contentArr) - 1
|
||||
lastBlock := contentArr[lastBlockIdx]
|
||||
|
||||
if cc := lastBlock.Get("cache_control"); cc.Exists() && cc.Get("ttl").String() != "" {
|
||||
return body
|
||||
}
|
||||
|
||||
pathPrefix := fmt.Sprintf("messages.%d.content.%d.cache_control", idx, lastBlockIdx)
|
||||
existingCC := lastBlock.Get("cache_control")
|
||||
if existingCC.Exists() {
|
||||
if next, err := sjson.SetBytes(body, pathPrefix+".ttl", claude.DefaultCacheControlTTL); err == nil {
|
||||
body = next
|
||||
}
|
||||
return body
|
||||
}
|
||||
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
|
||||
if next, err := sjson.SetRawBytes(body, pathPrefix, []byte(raw)); err == nil {
|
||||
body = next
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// mustJSONString 把一个 Go string 序列化为合法 JSON string(含引号),
|
||||
// 用于 sjson.SetRawBytes 场景下手工拼 JSON。
|
||||
func mustJSONString(s string) string {
|
||||
return fmt.Sprintf("%q", s)
|
||||
}
|
||||
@@ -9,6 +9,11 @@ import (
|
||||
)
|
||||
|
||||
func TestIsClaudeCodeClient(t *testing.T) {
|
||||
// 合法的 legacy 格式 metadata.user_id(64位 hex + account uuid + session uuid)
|
||||
legacyUserID := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
|
||||
// 合法的 JSON 格式 metadata.user_id(2.1.78+ 版本)
|
||||
jsonUserID := `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"123e4567-e89b-12d3-a456-426614174000"}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
@@ -16,15 +21,21 @@ func TestIsClaudeCodeClient(t *testing.T) {
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Claude Code client",
|
||||
name: "Claude Code client with legacy user_id",
|
||||
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
|
||||
metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
||||
metadataUserID: legacyUserID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Claude Code without version suffix",
|
||||
userAgent: "claude-cli/2.0.0",
|
||||
metadataUserID: "session_abc",
|
||||
name: "Claude Code client with JSON user_id",
|
||||
userAgent: "claude-cli/2.1.92 (external, cli)",
|
||||
metadataUserID: jsonUserID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Claude Code case insensitive UA",
|
||||
userAgent: "Claude-CLI/2.0.0",
|
||||
metadataUserID: legacyUserID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
@@ -34,21 +45,33 @@ func TestIsClaudeCodeClient(t *testing.T) {
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Different user agent",
|
||||
name: "Claude CLI UA with invalid user_id format",
|
||||
userAgent: "claude-cli/2.0.0",
|
||||
metadataUserID: "fake-user-id-12345",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Different user agent with valid user_id",
|
||||
userAgent: "curl/7.68.0",
|
||||
metadataUserID: "user123",
|
||||
metadataUserID: legacyUserID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Empty user agent",
|
||||
userAgent: "",
|
||||
metadataUserID: "user123",
|
||||
metadataUserID: legacyUserID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Similar but not Claude CLI",
|
||||
userAgent: "claude-api/1.0.0",
|
||||
metadataUserID: "user123",
|
||||
metadataUserID: legacyUserID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Opencode spoofing UA with arbitrary user_id",
|
||||
userAgent: "claude-cli/2.1.92",
|
||||
metadataUserID: "session_abc",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
@@ -378,16 +401,27 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
|
||||
err := json.Unmarshal(result, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
// system 应为 array 格式: [{type: "text", text: "...", cache_control: {type: "ephemeral"}}]
|
||||
// system 应为 array 格式,对齐真实 Claude Code CLI 的 2-block 形态:
|
||||
// [0] billing attribution block (x-anthropic-billing-header: cc_version=...;)
|
||||
// [1] Claude Code prompt block (带 cache_control)
|
||||
systemArr, ok := parsed["system"].([]any)
|
||||
require.True(t, ok, "system should be an array, got %T", parsed["system"])
|
||||
require.Len(t, systemArr, 1, "system array should have exactly 1 block")
|
||||
systemBlock, ok := systemArr[0].(map[string]any)
|
||||
require.Len(t, systemArr, 2, "system array should have exactly 2 blocks (billing + cc prompt)")
|
||||
|
||||
billingBlock, ok := systemArr[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", billingBlock["type"])
|
||||
require.Contains(t, billingBlock["text"], "x-anthropic-billing-header:")
|
||||
require.Contains(t, billingBlock["text"], "cc_version=")
|
||||
require.Contains(t, billingBlock["text"], "cc_entrypoint=cli")
|
||||
require.Contains(t, billingBlock["text"], "cch=00000")
|
||||
|
||||
systemBlock, ok := systemArr[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", systemBlock["type"])
|
||||
require.Equal(t, tt.wantSystemText, systemBlock["text"])
|
||||
cc, ok := systemBlock["cache_control"].(map[string]any)
|
||||
require.True(t, ok, "system block should have cache_control")
|
||||
require.True(t, ok, "cc prompt block should have cache_control")
|
||||
require.Equal(t, "ephemeral", cc["type"])
|
||||
|
||||
// 检查 messages
|
||||
|
||||
@@ -119,7 +119,7 @@ func openAIStreamEventIsTerminal(data string) bool {
|
||||
return true
|
||||
}
|
||||
switch gjson.Get(trimmed, "type").String() {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -329,7 +329,7 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
@@ -850,6 +850,7 @@ func (s *GatewayService) hashContent(content string) string {
|
||||
|
||||
type anthropicCacheControlPayload struct {
|
||||
Type string `json:"type"`
|
||||
TTL string `json:"ttl,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicSystemTextBlockPayload struct {
|
||||
@@ -898,7 +899,10 @@ func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]b
|
||||
Text: text,
|
||||
}
|
||||
if includeCacheControl {
|
||||
block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
|
||||
block.CacheControl = &anthropicCacheControlPayload{
|
||||
Type: "ephemeral",
|
||||
TTL: claude.DefaultCacheControlTTL,
|
||||
}
|
||||
}
|
||||
return json.Marshal(block)
|
||||
}
|
||||
@@ -1074,19 +1078,52 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
}
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out, "temperature").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
|
||||
// temperature:真实 Claude Code CLI 总是发送 temperature(默认 1,客户端可覆盖)。
|
||||
// 之前的实现直接 delete 会导致 payload 缺字段,与真实 CLI 字节级不一致。
|
||||
// 策略:客户端传了什么就透传;没传则补默认 1。
|
||||
if !gjson.GetBytes(out, "temperature").Exists() {
|
||||
if next, ok := setJSONValueBytes(out, "temperature", 1); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if gjson.GetBytes(out, "tool_choice").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||
|
||||
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
|
||||
if !gjson.GetBytes(out, "max_tokens").Exists() {
|
||||
if next, ok := setJSONValueBytes(out, "max_tokens", 128000); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// context_management:thinking.type 为 enabled/adaptive 时,真实 CLI 会自动
|
||||
// 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。
|
||||
// 客户端显式传了就透传;否则按 CLI 行为补齐。
|
||||
if !gjson.GetBytes(out, "context_management").Exists() {
|
||||
thinkingType := gjson.GetBytes(out, "thinking.type").String()
|
||||
if thinkingType == "enabled" || thinkingType == "adaptive" {
|
||||
const cmDefault = `{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}`
|
||||
if next, ok := setJSONRawBytes(out, "context_management", []byte(cmDefault)); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tool_choice:与 Parrot 对齐,不再无条件删除。
|
||||
// - 客户端传了 {"type":"tool","name":"X"} → 保留结构,name 由
|
||||
// applyToolNameRewriteToBody 同步映射为假名
|
||||
// - 其他形态(auto/any/none)原样透传
|
||||
// 如果 body 里完全没有 tools(空数组),tool_choice 没意义时才删除
|
||||
if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 {
|
||||
if gjson.GetBytes(out, "tool_choice").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body, modelID
|
||||
}
|
||||
@@ -1128,6 +1165,135 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
|
||||
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
|
||||
}
|
||||
|
||||
// applyClaudeCodeOAuthMimicryToBody 将"非 Claude Code 客户端 + Claude OAuth 账号"
|
||||
// 路径上原本只在 /v1/messages 里做的完整伪装应用到任意 body 上。
|
||||
//
|
||||
// 这是 /v1/messages 主路径上 rewriteSystemForNonClaudeCode +
|
||||
// normalizeClaudeOAuthRequestBody 流程的通用版,供 OpenAI 协议兼容层
|
||||
// (ForwardAsChatCompletions / ForwardAsResponses) 复用。
|
||||
//
|
||||
// 未抽离之前,OpenAI 协议兼容层仅做 injectClaudeCodePrompt(前置追加),
|
||||
// 而仓内 /v1/messages 路径自己的注释明确说过"仅前置追加无法通过 Anthropic
|
||||
// 第三方检测";那条注释就是本函数存在的根因。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx / c:用于读取指纹和 gateway settings;c 可为 nil(如 count_tokens)。
|
||||
// - account:必须是 OAuth 账号,且调用方已判断不是 Claude Code 客户端。
|
||||
// - body:已经 marshal 成 Anthropic /v1/messages 格式的请求体。
|
||||
// - systemRaw:body 中原始 system 字段(用于判断是否需要 rewrite)。
|
||||
// - model:最终会发给上游的模型 ID(用于 haiku 旁路 + metadata 版本选择)。
|
||||
//
|
||||
// 返回:改写后的 body。即使中间任何一步失败,也会退化成原 body(不会 panic)。
|
||||
func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
systemRaw any,
|
||||
model string,
|
||||
) []byte {
|
||||
if account == nil || !account.IsOAuth() || len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
systemRewritten := false
|
||||
if !strings.Contains(strings.ToLower(model), "haiku") {
|
||||
body = rewriteSystemForNonClaudeCode(body, systemRaw)
|
||||
systemRewritten = true
|
||||
}
|
||||
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
|
||||
|
||||
if s.identityService != nil && c != nil && c.Request != nil {
|
||||
if fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header); err == nil && fp != nil {
|
||||
mimicMPT := false
|
||||
if s.settingService != nil {
|
||||
_, mimicMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||
}
|
||||
if !mimicMPT {
|
||||
if uid := s.buildOAuthMetadataUserIDFromBody(ctx, account, fp, body); uid != "" {
|
||||
normalizeOpts.injectMetadata = true
|
||||
normalizeOpts.metadataUserID = uid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)
|
||||
|
||||
// Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
|
||||
// 对齐 Parrot transform_request 里剩余的字段级改写。三步顺序有语义约束:
|
||||
// 1) strip:先清除客户端的 messages[*].cache_control(多轮稳定性)
|
||||
// 2) breakpoints:再注入 2 个断点(最后一条 + 倒数第二个 user turn)
|
||||
// 3) tool rewrite:最后改 tools[*].name / tool_choice.name 并在 tools[-1]
|
||||
// 上打断点;mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。
|
||||
body = stripMessageCacheControl(body)
|
||||
body = addMessageCacheBreakpoints(body)
|
||||
|
||||
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
||||
body = applyToolNameRewriteToBody(body, rw)
|
||||
if c != nil {
|
||||
c.Set(toolNameRewriteKey, rw)
|
||||
}
|
||||
} else {
|
||||
body = applyToolsLastCacheBreakpoint(body)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// buildOAuthMetadataUserIDFromBody 是 buildOAuthMetadataUserID 的变体,
|
||||
// 适用于调用方手上没有 ParsedRequest 的场景(如 OpenAI 协议兼容层)。
|
||||
//
|
||||
// 与 buildOAuthMetadataUserID 的唯一区别:
|
||||
// - session hash 从 body 本体按同样规则重算,而不是读取 ParsedRequest 缓存值。
|
||||
// - 如果 body 里已经存在 metadata.user_id,则返回空(由 ensureClaudeOAuthMetadataUserID
|
||||
// 自行决定是否覆盖)。
|
||||
func (s *GatewayService) buildOAuthMetadataUserIDFromBody(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
fp *Fingerprint,
|
||||
body []byte,
|
||||
) string {
|
||||
_ = ctx
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
if existing := gjson.GetBytes(body, "metadata.user_id").String(); existing != "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(account.GetClaudeUserID())
|
||||
if userID == "" && fp != nil {
|
||||
userID = fp.ClientID
|
||||
}
|
||||
if userID == "" {
|
||||
userID = generateClientID()
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
if hash := hashBodyForSessionSeed(body); hash != "" {
|
||||
sessionID = generateSessionUUID(fmt.Sprintf("%d::%s", account.ID, hash))
|
||||
}
|
||||
|
||||
var uaVersion string
|
||||
if fp != nil {
|
||||
uaVersion = ExtractCLIVersion(fp.UserAgent)
|
||||
}
|
||||
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
|
||||
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
|
||||
}
|
||||
|
||||
// hashBodyForSessionSeed 为 sessionID 提供一个稳定但仅对本次请求特征化的种子。
|
||||
// 复用 SHA-256 + 截断,与 generateSessionUUID 的输入格式对齐。
|
||||
func hashBodyForSessionSeed(body []byte) string {
|
||||
if len(body) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(body)
|
||||
return fmt.Sprintf("%x", sum[:16])
|
||||
}
|
||||
|
||||
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
|
||||
func GenerateSessionUUID(seed string) string {
|
||||
return generateSessionUUID(seed)
|
||||
@@ -3543,23 +3709,19 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
}
|
||||
}
|
||||
|
||||
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
|
||||
// 简化判断:User-Agent 匹配 + metadata.user_id 存在
|
||||
// isClaudeCodeClient 判断请求是否来自真正的 Claude Code 客户端。
|
||||
// 判定条件:
|
||||
// 1. User-Agent 匹配 claude-cli/X.Y.Z(大小写不敏感)
|
||||
// 2. metadata.user_id 符合 Claude Code 格式(legacy 或 JSON 格式)
|
||||
//
|
||||
// 只检查 metadata.user_id 非空不够严格:第三方工具(opencode 等)可能伪造 UA
|
||||
// 并附带任意 metadata.user_id 字符串,从而绕过 mimicry。必须通过 ParseMetadataUserID
|
||||
// 验证格式才能确认是真正的 Claude Code 客户端。
|
||||
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
if metadataUserID == "" {
|
||||
if !claudeCliUserAgentRe.MatchString(userAgent) {
|
||||
return false
|
||||
}
|
||||
return claudeCliUserAgentRe.MatchString(userAgent)
|
||||
}
|
||||
|
||||
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return true
|
||||
}
|
||||
if parsed == nil || c == nil {
|
||||
return false
|
||||
}
|
||||
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||
return ParseMetadataUserID(metadataUserID) != nil
|
||||
}
|
||||
|
||||
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
|
||||
@@ -3738,17 +3900,20 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
|
||||
originalSystemText = strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致)
|
||||
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
|
||||
// 使用 string 格式会被 Anthropic 检测为第三方应用。
|
||||
claudeCodeSystemBlock := []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": claudeCodeSystemPrompt,
|
||||
"cache_control": map[string]string{"type": "ephemeral"},
|
||||
},
|
||||
// 2. 构造 system 数组,对齐真实 Claude Code CLI 的 2-block 形态:
|
||||
// [0] billing attribution block(cc_version={cliVer}.{fp}; cc_entrypoint=cli; cch=00000;)
|
||||
// [1] "You are Claude Code..." prompt block(带 cache_control 作为稳定缓存断点)
|
||||
//
|
||||
// billing block 的 cch=00000 是占位符,会被 buildUpstreamRequest 里的
|
||||
// signBillingHeaderCCH 替换成 xxhash64 签名。缺失 billing block 的系统 payload
|
||||
// 是 Anthropic 判定第三方的关键信号之一(真实 CLI 每个请求都带)。
|
||||
billingBlock, billingErr := buildBillingAttributionBlockJSON(body, claude.CLICurrentVersion)
|
||||
ccPromptBlock, ccErr := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
|
||||
if billingErr != nil || ccErr != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to build system blocks (billing=%v, cc=%v)", billingErr, ccErr)
|
||||
return body
|
||||
}
|
||||
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
|
||||
out, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw([][]byte{billingBlock, ccPromptBlock}))
|
||||
if !ok {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
|
||||
return body
|
||||
@@ -3985,15 +4150,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
})
|
||||
}
|
||||
|
||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||
// Claude Code 客户端判定:UA 匹配 claude-cli/* 且携带 metadata.user_id。
|
||||
// 真正的 Claude Code 客户端自带完整的 system prompt、cache_control 断点和 header,
|
||||
// 不需要代理做任何 body 级别的 mimicry;强行替换反而会破坏客户端的缓存策略
|
||||
// (长 system prompt 被替换为 ~45 tokens 的短 prompt,低于 Anthropic 1024 token
|
||||
// 最低缓存门槛,导致系统级缓存失效)。
|
||||
//
|
||||
// 对于非 Claude Code 的第三方客户端(opencode 等),仍然走完整 mimicry。
|
||||
isClaudeCode := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
// 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
|
||||
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||
// 与 Parrot 对齐:OAuth 账号无条件重写 system(即使客户端已发了 Claude Code
|
||||
// 风格的 system prompt)。原因:第三方工具(opencode 等)会发 "You are Claude
|
||||
// Code..." system prompt 但缺少 billing attribution block,导致 Anthropic
|
||||
// 检测到"有 CC prompt 但无 billing block"的不一致而判为 third-party。
|
||||
// Parrot 的 transform_request 从不检查客户端 system 内容,直接覆盖。
|
||||
systemRewritten := false
|
||||
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||
if !strings.Contains(strings.ToLower(reqModel), "haiku") {
|
||||
body = rewriteSystemForNonClaudeCode(body, parsed.System)
|
||||
systemRewritten = true
|
||||
}
|
||||
@@ -4017,6 +4191,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
|
||||
// D/E/F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
|
||||
// 与 forward_as_chat_completions / forward_as_responses 路径对齐,
|
||||
// 保证原生 /v1/messages 路径也经过完整的 Parrot 字段级改写。
|
||||
body = stripMessageCacheControl(body)
|
||||
body = addMessageCacheBreakpoints(body)
|
||||
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
||||
body = applyToolNameRewriteToBody(body, rw)
|
||||
c.Set(toolNameRewriteKey, rw)
|
||||
} else {
|
||||
body = applyToolsLastCacheBreakpoint(body)
|
||||
}
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
@@ -4955,7 +5141,8 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
if _, err := io.WriteString(w, line); err != nil {
|
||||
restored := string(reverseToolNamesIfPresent(c, []byte(line)))
|
||||
if _, err := io.WriteString(w, restored); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||
} else if _, err := io.WriteString(w, "\n"); err != nil {
|
||||
@@ -5125,6 +5312,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
body = reverseToolNamesIfPresent(c, body)
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
return usage, nil
|
||||
}
|
||||
@@ -5580,13 +5768,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
setHeaderRaw(req.Header, "x-api-key", token)
|
||||
}
|
||||
|
||||
// 白名单透传headers(恢复真实 wire casing)
|
||||
for key, values := range clientHeaders {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if allowedHeaders[lowerKey] {
|
||||
wireKey := resolveWireCasing(key)
|
||||
for _, v := range values {
|
||||
addHeaderRaw(req.Header, wireKey, v)
|
||||
// 白名单透传 headers
|
||||
// OAuth mimicry 路径:跳过客户端 header 透传,与 Parrot 对齐。
|
||||
// Parrot 的 build_upstream_headers 只发 9 个精确 header,不透传任何客户端 header。
|
||||
// 透传客户端 header 会引入不一致的 x-stainless-* / anthropic-beta / user-agent /
|
||||
// x-claude-code-session-id 等值,和我们注入的伪装 header 冲突,被 Anthropic 判 third-party。
|
||||
if tokenType != "oauth" || !mimicClaudeCode {
|
||||
for key, values := range clientHeaders {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if allowedHeaders[lowerKey] {
|
||||
wireKey := resolveWireCasing(key)
|
||||
for _, v := range values {
|
||||
addHeaderRaw(req.Header, wireKey, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5627,7 +5821,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// Haiku models are exempt from third-party detection and don't need it.
|
||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||
if !strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||
requiredBetas = claude.FullClaudeCodeMimicryBetas()
|
||||
}
|
||||
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
|
||||
} else {
|
||||
@@ -6099,6 +6293,11 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
|
||||
if isStream {
|
||||
setHeaderRaw(req.Header, "x-stainless-helper-method", "stream")
|
||||
}
|
||||
// Real Claude CLI 每个请求都会生成一个新的 UUID 放在 x-client-request-id。
|
||||
// 上游会以此作为会话/请求指纹的一部分,缺失或重复都可能触发第三方判定。
|
||||
if getHeaderRaw(req.Header, "x-client-request-id") == "" {
|
||||
setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString())
|
||||
}
|
||||
}
|
||||
|
||||
func truncateForLog(b []byte, maxBytes int) string {
|
||||
@@ -6864,7 +7063,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
for _, block := range outputBlocks {
|
||||
if !clientDisconnected {
|
||||
if _, werr := fmt.Fprint(w, block); werr != nil {
|
||||
restored := reverseToolNamesIfPresent(c, []byte(block))
|
||||
if _, werr := fmt.Fprint(w, string(restored)); werr != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
break
|
||||
@@ -7206,6 +7406,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
}
|
||||
}
|
||||
|
||||
body = reverseToolNamesIfPresent(c, body)
|
||||
|
||||
// 写入响应
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
@@ -8194,12 +8396,20 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
// Pre-filter: strip empty text blocks to prevent upstream 400.
|
||||
body = StripEmptyTextBlocks(body)
|
||||
|
||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
isClaudeCodeCT := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCodeCT
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
|
||||
body = stripMessageCacheControl(body)
|
||||
body = addMessageCacheBreakpoints(body)
|
||||
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
||||
body = applyToolNameRewriteToBody(body, rw)
|
||||
} else {
|
||||
body = applyToolsLastCacheBreakpoint(body)
|
||||
}
|
||||
}
|
||||
|
||||
// Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
|
||||
@@ -8623,7 +8833,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
applyClaudeCodeMimicHeaders(req, false)
|
||||
|
||||
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
|
||||
requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
|
||||
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
|
||||
} else {
|
||||
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
|
||||
|
||||
313
backend/internal/service/gateway_tool_rewrite.go
Normal file
313
backend/internal/service/gateway_tool_rewrite.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// toolNameRewriteKey 是 gin.Context 上存 ToolNameRewrite 映射的 key。
|
||||
// 请求阶段写入,响应阶段读取,用于 bytes 级逆向还原假名 → 真名。
|
||||
const toolNameRewriteKey = "claude_tool_name_rewrite"
|
||||
|
||||
// staticToolNameRewrites 是"静态前缀映射",与 Parrot src/transform/cc_mimicry.py
|
||||
// TOOL_NAME_REWRITES 完全一致。只有以这些前缀开头的工具会被重写。
|
||||
var staticToolNameRewrites = map[string]string{
|
||||
"sessions_": "cc_sess_",
|
||||
"session_": "cc_ses_",
|
||||
}
|
||||
|
||||
// fakeToolNamePrefixes 是"动态映射"的前缀池,与 Parrot _FAKE_PREFIXES 一致。
|
||||
// 当 tools 数量 > dynamicToolMapThreshold 时随机选用其中前缀生成可读假名。
|
||||
var fakeToolNamePrefixes = []string{
|
||||
"analyze_", "compute_", "fetch_", "generate_", "lookup_", "modify_",
|
||||
"process_", "query_", "render_", "resolve_", "sync_", "update_",
|
||||
"validate_", "convert_", "extract_", "manage_", "monitor_", "parse_",
|
||||
"review_", "search_", "transform_", "handle_", "invoke_", "notify_",
|
||||
}
|
||||
|
||||
// dynamicToolMapThreshold 与 Parrot 一致:tools 数量超过 5 才启用动态映射。
|
||||
// 少量工具不需要混淆(一般是 Claude Code 自己的核心工具 bash/edit/read 等)。
|
||||
const dynamicToolMapThreshold = 5
|
||||
|
||||
// ToolNameRewrite 是单次请求内的工具名混淆映射。
|
||||
// - Forward: real → fake,请求阶段在 body 上应用。
|
||||
// - Reverse: fake → real,响应阶段对每个 chunk 做 bytes.Replace 还原。
|
||||
//
|
||||
// ReverseOrdered 是按假名长度倒序的 (fake, real) 列表,用于防止短假名是长假名的
|
||||
// 子串时 bytes.Replace 先被吃掉(对齐 Parrot _restore_tool_names_in_chunk 的
|
||||
// `sorted(..., key=lambda x: len(x[1]), reverse=True)`)。
|
||||
type ToolNameRewrite struct {
|
||||
Forward map[string]string
|
||||
Reverse map[string]string
|
||||
ReverseOrdered [][2]string
|
||||
}
|
||||
|
||||
// buildDynamicToolMap 构造 tools 的动态假名映射。
|
||||
//
|
||||
// 与 Parrot _build_dynamic_tool_map 语义等价:
|
||||
// - tools 数量 ≤ dynamicToolMapThreshold 时返回 nil(不做动态映射,走静态 fallback)
|
||||
// - 同一组 tool_names 在同进程内映射稳定(保证 cache 命中)
|
||||
//
|
||||
// Parrot 用 `random.Random(hash(tuple(tool_names)))` 作 seed + shuffle 前缀池;
|
||||
// Go 无法字节级复刻 Python hash,但"稳定性"和"前缀池打散"两个不变量都保留:
|
||||
// 用 fnv64a(strings.Join(names, "\x00")) 作 seed 喂 math/rand.New。
|
||||
// 字节级不同不影响上游判定(Anthropic 不会验证我们的随机种子算法)。
|
||||
func buildDynamicToolMap(toolNames []string) map[string]string {
|
||||
if len(toolNames) <= dynamicToolMapThreshold {
|
||||
return nil
|
||||
}
|
||||
h := fnv.New64a()
|
||||
for i, n := range toolNames {
|
||||
if i > 0 {
|
||||
_, _ = h.Write([]byte{0})
|
||||
}
|
||||
_, _ = h.Write([]byte(n))
|
||||
}
|
||||
rng := rand.New(rand.NewSource(int64(h.Sum64())))
|
||||
|
||||
available := make([]string, len(fakeToolNamePrefixes))
|
||||
copy(available, fakeToolNamePrefixes)
|
||||
rng.Shuffle(len(available), func(i, j int) { available[i], available[j] = available[j], available[i] })
|
||||
|
||||
mapping := make(map[string]string, len(toolNames))
|
||||
for i, name := range toolNames {
|
||||
prefix := available[i%len(available)]
|
||||
headLen := 3
|
||||
if len(name) < 3 {
|
||||
headLen = len(name)
|
||||
}
|
||||
fake := fmt.Sprintf("%s%s%02d", prefix, name[:headLen], i)
|
||||
mapping[name] = fake
|
||||
}
|
||||
return mapping
|
||||
}
|
||||
|
||||
// sanitizeToolName 把真名转成假名。
|
||||
// 与 Parrot _sanitize_tool_name 语义一致:动态映射优先,再走静态前缀映射。
|
||||
func sanitizeToolName(name string, dynamic map[string]string) string {
|
||||
if dynamic != nil {
|
||||
if fake, ok := dynamic[name]; ok {
|
||||
return fake
|
||||
}
|
||||
}
|
||||
for prefix, replacement := range staticToolNameRewrites {
|
||||
if strings.HasPrefix(name, prefix) {
|
||||
return replacement + name[len(prefix):]
|
||||
}
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// shouldMimicToolName 指示某个 tool 是否需要重命名。
|
||||
// server tool(type != "" 且不是 "function" / "custom")是 Anthropic 协议语义的一部分,
|
||||
// 比如 "web_search_20250305" / "computer_20250124";误改会导致上游拒绝。
|
||||
func shouldMimicToolName(toolType string) bool {
|
||||
if toolType == "" || toolType == "function" || toolType == "custom" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// buildToolNameRewriteFromBody 扫描 body 的 tools[*].name,构造 ToolNameRewrite
|
||||
// 并返回它。若不需要混淆(tools 数量不足 + 没有匹配静态前缀的工具)返回 nil。
|
||||
//
|
||||
// 注意:只扫描,不改 body。真正的 body 改写在 applyToolNameRewriteToBody。
|
||||
func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return nil
|
||||
}
|
||||
|
||||
mimicableNames := make([]string, 0)
|
||||
toolsArr := tools.Array()
|
||||
for _, t := range toolsArr {
|
||||
if !shouldMimicToolName(t.Get("type").String()) {
|
||||
continue
|
||||
}
|
||||
name := t.Get("name").String()
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
mimicableNames = append(mimicableNames, name)
|
||||
}
|
||||
|
||||
dynamic := buildDynamicToolMap(mimicableNames)
|
||||
|
||||
rw := &ToolNameRewrite{
|
||||
Forward: make(map[string]string),
|
||||
Reverse: make(map[string]string),
|
||||
}
|
||||
for _, name := range mimicableNames {
|
||||
fake := sanitizeToolName(name, dynamic)
|
||||
if fake == name {
|
||||
continue
|
||||
}
|
||||
rw.Forward[name] = fake
|
||||
rw.Reverse[fake] = name
|
||||
}
|
||||
if len(rw.Forward) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rw.ReverseOrdered = make([][2]string, 0, len(rw.Reverse))
|
||||
for fake, real := range rw.Reverse {
|
||||
rw.ReverseOrdered = append(rw.ReverseOrdered, [2]string{fake, real})
|
||||
}
|
||||
sort.SliceStable(rw.ReverseOrdered, func(i, j int) bool {
|
||||
return len(rw.ReverseOrdered[i][0]) > len(rw.ReverseOrdered[j][0])
|
||||
})
|
||||
|
||||
return rw
|
||||
}
|
||||
|
||||
// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上:
|
||||
// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool)
|
||||
// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点(Parrot 行为对齐,
|
||||
// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL)
|
||||
// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool")
|
||||
//
|
||||
// 历史 $.messages[*].content[*].name(tool_use)不在请求侧改写——这与 Parrot 一致;
|
||||
// 响应侧 bytes.Replace 会连带还原它们。
|
||||
func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
|
||||
if rw == nil || len(rw.Forward) == 0 {
|
||||
body = applyToolsLastCacheBreakpoint(body)
|
||||
return body
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.IsArray() {
|
||||
idx := -1
|
||||
tools.ForEach(func(_, t gjson.Result) bool {
|
||||
idx++
|
||||
if !shouldMimicToolName(t.Get("type").String()) {
|
||||
return true
|
||||
}
|
||||
name := t.Get("name").String()
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
fake, ok := rw.Forward[name]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.name", idx), fake); err == nil {
|
||||
body = next
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if tc := gjson.GetBytes(body, "tool_choice"); tc.Exists() && tc.Get("type").String() == "tool" {
|
||||
name := tc.Get("name").String()
|
||||
if fake, ok := rw.Forward[name]; ok {
|
||||
if next, err := sjson.SetBytes(body, "tool_choice.name", fake); err == nil {
|
||||
body = next
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body = applyToolsLastCacheBreakpoint(body)
|
||||
return body
|
||||
}
|
||||
|
||||
// applyToolsLastCacheBreakpoint 在 tools 数组最后一个工具上注入 cache_control
|
||||
// 断点,对齐 Parrot `tools[-1]["cache_control"] = {"type":"ephemeral","ttl":"1h"}`
|
||||
// 行为,但 ttl 按本仓规则:
|
||||
// - 客户端已为该 tool 显式设置 cache_control.ttl → 完全透传不覆盖
|
||||
// - 否则注入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
|
||||
//
|
||||
// 纯副作用函数,tools 不存在或为空数组时 no-op。
|
||||
func applyToolsLastCacheBreakpoint(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return body
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) == 0 {
|
||||
return body
|
||||
}
|
||||
lastIdx := len(arr) - 1
|
||||
existingCC := arr[lastIdx].Get("cache_control")
|
||||
|
||||
if existingCC.Exists() && existingCC.Get("ttl").String() != "" {
|
||||
return body
|
||||
}
|
||||
|
||||
if existingCC.Exists() {
|
||||
if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.cache_control.ttl", lastIdx), claude.DefaultCacheControlTTL); err == nil {
|
||||
body = next
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
|
||||
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("tools.%d.cache_control", lastIdx), []byte(raw)); err == nil {
|
||||
body = next
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// restoreToolNamesInBytes 对 bytes chunk 做逆向还原:假名 → 真名。
|
||||
// 按 ReverseOrdered 的假名长度倒序逐个 bytes.Replace,防止子串冲突
|
||||
// (与 Parrot _restore_tool_names_in_chunk 的 sorted(..., reverse=True) 等价)。
|
||||
// 再做静态前缀还原(cc_sess_ → sessions_ / cc_ses_ → session_)。
|
||||
//
|
||||
// rw 可为 nil;nil 时仍会做静态前缀还原。
|
||||
func restoreToolNamesInBytes(data []byte, rw *ToolNameRewrite) []byte {
|
||||
if rw != nil {
|
||||
for _, pair := range rw.ReverseOrdered {
|
||||
fake, real := pair[0], pair[1]
|
||||
if fake == "" || fake == real {
|
||||
continue
|
||||
}
|
||||
data = replaceAllBytes(data, fake, real)
|
||||
}
|
||||
}
|
||||
for prefix, replacement := range staticToolNameRewrites {
|
||||
data = replaceAllBytes(data, replacement, prefix)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// replaceAllBytes 是 bytes.ReplaceAll 的便捷封装,避免每个调用点各自做 []byte 转换。
|
||||
func replaceAllBytes(data []byte, from, to string) []byte {
|
||||
if len(data) == 0 || from == to || !strings.Contains(string(data), from) {
|
||||
return data
|
||||
}
|
||||
return []byte(strings.ReplaceAll(string(data), from, to))
|
||||
}
|
||||
|
||||
// toolNameRewriteFromContext 从 gin.Context 取出请求阶段保存的工具名映射。
|
||||
// 找不到(c==nil 或 key 不存在或类型不对)时返回 nil;调用方必须能处理 nil。
|
||||
func toolNameRewriteFromContext(c interface {
|
||||
Get(string) (any, bool)
|
||||
}) *ToolNameRewrite {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := c.Get(toolNameRewriteKey)
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
rw, _ := raw.(*ToolNameRewrite)
|
||||
return rw
|
||||
}
|
||||
|
||||
// reverseToolNamesIfPresent 是响应侧 5 处注入点的统一封装:从 c 取出 mapping
|
||||
// 并对 chunk 做 bytes 级假名→真名替换。c 没有 mapping 时仍会做静态前缀还原。
|
||||
func reverseToolNamesIfPresent(c interface {
|
||||
Get(string) (any, bool)
|
||||
}, chunk []byte) []byte {
|
||||
rw := toolNameRewriteFromContext(c)
|
||||
if rw == nil && len(staticToolNameRewrites) == 0 {
|
||||
return chunk
|
||||
}
|
||||
return restoreToolNamesInBytes(chunk, rw)
|
||||
}
|
||||
185
backend/internal/service/gateway_tool_rewrite_test.go
Normal file
185
backend/internal/service/gateway_tool_rewrite_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) {
|
||||
// Parrot 行为:tools 数量 ≤ 5 时不做动态映射。
|
||||
names := []string{"bash", "edit", "read", "write", "search"}
|
||||
require.Nil(t, buildDynamicToolMap(names))
|
||||
}
|
||||
|
||||
func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) {
|
||||
// Parrot 不变量:同一组 tool_names 在同进程内映射稳定(保证 cache 命中)。
|
||||
names := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}
|
||||
a := buildDynamicToolMap(names)
|
||||
b := buildDynamicToolMap(names)
|
||||
require.NotNil(t, a)
|
||||
require.Equal(t, a, b, "same input tool_names must yield identical mapping")
|
||||
require.Len(t, a, 6)
|
||||
for _, name := range names {
|
||||
require.Contains(t, a, name)
|
||||
require.NotEqual(t, name, a[name])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeToolName_StaticPrefix(t *testing.T) {
|
||||
require.Equal(t, "cc_sess_list", sanitizeToolName("sessions_list", nil))
|
||||
require.Equal(t, "cc_ses_get", sanitizeToolName("session_get", nil))
|
||||
require.Equal(t, "bash", sanitizeToolName("bash", nil))
|
||||
}
|
||||
|
||||
func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) {
|
||||
dyn := map[string]string{"sessions_list": "analyze_ses00"}
|
||||
got := sanitizeToolName("sessions_list", dyn)
|
||||
require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix")
|
||||
}
|
||||
|
||||
func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) {
|
||||
// 当假名 "abc_12" 是另一个更长假名的子串(真实场景极少但算法必须防御)时,
|
||||
// 长的必须先替换。本测试用显式构造的映射来验证排序不变量。
|
||||
rw := &ToolNameRewrite{
|
||||
Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"},
|
||||
Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"},
|
||||
}
|
||||
// 手工构造 ReverseOrdered:长的在前
|
||||
rw.ReverseOrdered = [][2]string{
|
||||
{"abc_12_ext", "bar"},
|
||||
{"abc_12", "foo"},
|
||||
}
|
||||
data := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`)
|
||||
restored := string(restoreToolNamesInBytes(data, rw))
|
||||
require.Equal(t, `{"tool":"bar","other":"foo"}`, restored)
|
||||
}
|
||||
|
||||
func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) {
|
||||
data := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`)
|
||||
got := string(restoreToolNamesInBytes(data, nil))
|
||||
require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, got)
|
||||
}
|
||||
|
||||
func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`)
|
||||
rw := buildToolNameRewriteFromBody(body)
|
||||
require.NotNil(t, rw)
|
||||
require.Contains(t, rw.Forward, "sessions_list")
|
||||
require.Contains(t, rw.Forward, "session_get")
|
||||
// web_search is a server tool, not rewritten
|
||||
require.NotContains(t, rw.Forward, "web_search")
|
||||
|
||||
out := applyToolNameRewriteToBody(body, rw)
|
||||
|
||||
// tools[0].name and tools[1].name rewritten; tools[2].name untouched
|
||||
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
|
||||
require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String())
|
||||
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String())
|
||||
|
||||
// tool_choice.name rewritten
|
||||
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String())
|
||||
require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
|
||||
}
|
||||
|
||||
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
|
||||
out := applyToolsLastCacheBreakpoint(body)
|
||||
require.Equal(t, "ephemeral", gjson.GetBytes(out, "tools.1.cache_control.type").String())
|
||||
require.Equal(t, "5m", gjson.GetBytes(out, "tools.1.cache_control.ttl").String())
|
||||
// First tool untouched
|
||||
require.False(t, gjson.GetBytes(out, "tools.0.cache_control").Exists())
|
||||
}
|
||||
|
||||
func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`)
|
||||
out := applyToolsLastCacheBreakpoint(body)
|
||||
// User-provided ttl must be preserved.
|
||||
require.Equal(t, "1h", gjson.GetBytes(out, "tools.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestStripMessageCacheControl(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`)
|
||||
out := stripMessageCacheControl(body)
|
||||
require.False(t, gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists())
|
||||
}
|
||||
|
||||
func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
out := addMessageCacheBreakpoints(body)
|
||||
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
|
||||
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) {
|
||||
// Parrot 不变量:messages ≥ 4 时才打第二个断点,且位置是"倒数第二个 user turn"。
|
||||
body := []byte(`{"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"q1"}]},
|
||||
{"role":"assistant","content":[{"type":"text","text":"a1"}]},
|
||||
{"role":"user","content":[{"type":"text","text":"q2"}]},
|
||||
{"role":"assistant","content":[{"type":"text","text":"a2"}]}
|
||||
]}`)
|
||||
out := addMessageCacheBreakpoints(body)
|
||||
// 最后一条 assistant 被打断点
|
||||
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String())
|
||||
// 倒数第二个 user turn = index 0(唯一另一个 user)
|
||||
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
|
||||
// 其他不打断点
|
||||
require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists())
|
||||
require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists())
|
||||
}
|
||||
|
||||
func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
|
||||
out := addMessageCacheBreakpoints(body)
|
||||
// content 升级成数组
|
||||
require.True(t, gjson.GetBytes(out, "messages.0.content").IsArray())
|
||||
require.Equal(t, "text", gjson.GetBytes(out, "messages.0.content.0.type").String())
|
||||
require.Equal(t, "hi", gjson.GetBytes(out, "messages.0.content.0.text").String())
|
||||
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
|
||||
// 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列
|
||||
body := []byte(`{"tools":[
|
||||
{"name":"t1","input_schema":{}},
|
||||
{"name":"t2","input_schema":{}},
|
||||
{"name":"t3","input_schema":{}},
|
||||
{"name":"t4","input_schema":{}},
|
||||
{"name":"t5","input_schema":{}},
|
||||
{"name":"t6","input_schema":{}}
|
||||
]}`)
|
||||
rw := buildToolNameRewriteFromBody(body)
|
||||
require.NotNil(t, rw)
|
||||
require.NotEmpty(t, rw.ReverseOrdered)
|
||||
for i := 1; i < len(rw.ReverseOrdered); i++ {
|
||||
require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]),
|
||||
"ReverseOrdered must be sorted by fake-name length descending")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) {
|
||||
data := []byte("plain text without any tool names")
|
||||
require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil)))
|
||||
}
|
||||
|
||||
// Ensure the fake name format follows Parrot's "{prefix}{name[:3]}{i:02d}".
|
||||
func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) {
|
||||
names := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"}
|
||||
m := buildDynamicToolMap(names)
|
||||
require.NotNil(t, m)
|
||||
for _, name := range names {
|
||||
fake, ok := m[name]
|
||||
require.True(t, ok)
|
||||
// fake = prefix + head3 + "%02d"
|
||||
// ends with two decimal digits
|
||||
require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake)
|
||||
head := name
|
||||
if len(head) > 3 {
|
||||
head = head[:3]
|
||||
}
|
||||
require.True(t, strings.Contains(fake, head), "fake %q should contain head3 %q of %q", fake, head, name)
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,7 @@ var (
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.1.22 (external, cli)",
|
||||
UserAgent: "claude-cli/2.1.92 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.70.0",
|
||||
StainlessOS: "Linux",
|
||||
|
||||
@@ -3,7 +3,6 @@ package service
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math"
|
||||
@@ -45,6 +44,7 @@ type OpenAIAccountScheduleRequest struct {
|
||||
RequestedModel string
|
||||
RequiredTransport OpenAIUpstreamTransport
|
||||
RequiredImageCapability OpenAIImagesCapability
|
||||
RequireCompact bool
|
||||
ExcludedIDs map[int64]struct{}
|
||||
}
|
||||
|
||||
@@ -258,12 +258,16 @@ func (s *defaultOpenAIAccountScheduler) Select(
|
||||
previousResponseID,
|
||||
req.RequestedModel,
|
||||
req.ExcludedIDs,
|
||||
req.RequireCompact,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
selection = nil
|
||||
}
|
||||
}
|
||||
@@ -348,8 +352,8 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
|
||||
if account == nil {
|
||||
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact)
|
||||
if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -590,7 +594,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
|
||||
return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
|
||||
}
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
@@ -630,7 +634,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
})
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
|
||||
return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
|
||||
}
|
||||
|
||||
loadMap := map[int64]*AccountLoadInfo{}
|
||||
@@ -640,45 +644,14 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
}
|
||||
}
|
||||
|
||||
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
|
||||
maxWaiting := 1
|
||||
loadRateSum := 0.0
|
||||
loadRateSumSquares := 0.0
|
||||
minTTFT, maxTTFT := 0.0, 0.0
|
||||
hasTTFTSample := false
|
||||
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||
for _, account := range filtered {
|
||||
loadInfo := loadMap[account.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
||||
}
|
||||
if account.Priority < minPriority {
|
||||
minPriority = account.Priority
|
||||
}
|
||||
if account.Priority > maxPriority {
|
||||
maxPriority = account.Priority
|
||||
}
|
||||
if loadInfo.WaitingCount > maxWaiting {
|
||||
maxWaiting = loadInfo.WaitingCount
|
||||
}
|
||||
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
|
||||
if hasTTFT && ttft > 0 {
|
||||
if !hasTTFTSample {
|
||||
minTTFT, maxTTFT = ttft, ttft
|
||||
hasTTFTSample = true
|
||||
} else {
|
||||
if ttft < minTTFT {
|
||||
minTTFT = ttft
|
||||
}
|
||||
if ttft > maxTTFT {
|
||||
maxTTFT = ttft
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRate := float64(loadInfo.LoadRate)
|
||||
loadRateSum += loadRate
|
||||
loadRateSumSquares += loadRate * loadRate
|
||||
candidates = append(candidates, openAIAccountCandidateScore{
|
||||
allCandidates = append(allCandidates, openAIAccountCandidateScore{
|
||||
account: account,
|
||||
loadInfo: loadInfo,
|
||||
errorRate: errorRate,
|
||||
@@ -686,53 +659,183 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
hasTTFT: hasTTFT,
|
||||
})
|
||||
}
|
||||
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||
|
||||
weights := s.service.openAIWSSchedulerWeights()
|
||||
for i := range candidates {
|
||||
item := &candidates[i]
|
||||
priorityFactor := 1.0
|
||||
if maxPriority > minPriority {
|
||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
|
||||
// 时作为最后兜底(snapshot 可能已陈旧)。
|
||||
candidates := allCandidates
|
||||
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
if req.RequireCompact {
|
||||
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
for _, candidate := range allCandidates {
|
||||
if openAICompactSupportTier(candidate.account) == 0 {
|
||||
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||
errorFactor := 1 - clamp01(item.errorRate)
|
||||
ttftFactor := 0.5
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
|
||||
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor
|
||||
}
|
||||
|
||||
topK := s.service.openAIWSLBTopK()
|
||||
if topK > len(candidates) {
|
||||
topK = len(candidates)
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
|
||||
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
|
||||
candidateCount := len(candidates)
|
||||
loadSkew := 0.0
|
||||
if len(candidates) > 0 {
|
||||
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
|
||||
maxWaiting := 1
|
||||
loadRateSum := 0.0
|
||||
loadRateSumSquares := 0.0
|
||||
minTTFT, maxTTFT := 0.0, 0.0
|
||||
hasTTFTSample := false
|
||||
for _, candidate := range candidates {
|
||||
if candidate.account.Priority < minPriority {
|
||||
minPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.account.Priority > maxPriority {
|
||||
maxPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.loadInfo.WaitingCount > maxWaiting {
|
||||
maxWaiting = candidate.loadInfo.WaitingCount
|
||||
}
|
||||
if candidate.hasTTFT && candidate.ttft > 0 {
|
||||
if !hasTTFTSample {
|
||||
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
|
||||
hasTTFTSample = true
|
||||
} else {
|
||||
if candidate.ttft < minTTFT {
|
||||
minTTFT = candidate.ttft
|
||||
}
|
||||
if candidate.ttft > maxTTFT {
|
||||
maxTTFT = candidate.ttft
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRate := float64(candidate.loadInfo.LoadRate)
|
||||
loadRateSum += loadRate
|
||||
loadRateSumSquares += loadRate * loadRate
|
||||
}
|
||||
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||
|
||||
weights := s.service.openAIWSSchedulerWeights()
|
||||
for i := range candidates {
|
||||
item := &candidates[i]
|
||||
priorityFactor := 1.0
|
||||
if maxPriority > minPriority {
|
||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||
}
|
||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||
errorFactor := 1 - clamp01(item.errorRate)
|
||||
ttftFactor := 0.5
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
}
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor
|
||||
}
|
||||
}
|
||||
|
||||
topK := 0
|
||||
if len(candidates) > 0 {
|
||||
topK = s.service.openAIWSLBTopK()
|
||||
if topK > len(candidates) {
|
||||
topK = len(candidates)
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
}
|
||||
|
||||
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 || topK <= 0 {
|
||||
return nil
|
||||
}
|
||||
groupTopK := topK
|
||||
if groupTopK > len(pool) {
|
||||
groupTopK = len(pool)
|
||||
}
|
||||
ranked := selectTopKOpenAICandidates(pool, groupTopK)
|
||||
return buildOpenAIWeightedSelectionOrder(ranked, req)
|
||||
}
|
||||
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 {
|
||||
return nil
|
||||
}
|
||||
ordered := append([]openAIAccountCandidateScore(nil), pool...)
|
||||
sort.SliceStable(ordered, func(i, j int) bool {
|
||||
a, b := ordered[i], ordered[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
|
||||
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
return ordered
|
||||
}
|
||||
|
||||
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
if req.RequireCompact {
|
||||
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
|
||||
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
switch openAICompactSupportTier(candidate.account) {
|
||||
case 2:
|
||||
supported = append(supported, candidate)
|
||||
case 1:
|
||||
unknown = append(unknown, candidate)
|
||||
}
|
||||
}
|
||||
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
|
||||
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
|
||||
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
|
||||
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
|
||||
}
|
||||
} else {
|
||||
selectionOrder = buildSelectionOrder(candidates)
|
||||
}
|
||||
if len(selectionOrder) == 0 {
|
||||
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
|
||||
}
|
||||
|
||||
compactBlocked := false
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
compactBlocked = true
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, len(candidates), topK, loadSkew, acquireErr
|
||||
return nil, candidateCount, topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
@@ -742,17 +845,25 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}, candidateCount, topK, loadSkew, nil
|
||||
}
|
||||
}
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
for _, candidate := range selectionOrder {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
compactBlocked = true
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
@@ -761,10 +872,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}, candidateCount, topK, loadSkew, nil
|
||||
}
|
||||
|
||||
return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
|
||||
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, compactBlocked)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
@@ -905,8 +1016,9 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
requiredTransport OpenAIUpstreamTransport,
|
||||
requireCompact bool,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "")
|
||||
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
|
||||
@@ -917,13 +1029,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
|
||||
excludedIDs map[int64]struct{},
|
||||
requiredCapability OpenAIImagesCapability,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability)
|
||||
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false)
|
||||
if err == nil && selection != nil && selection.Account != nil {
|
||||
return selection, decision, nil
|
||||
}
|
||||
// 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号)
|
||||
if requiredCapability == OpenAIImagesCapabilityNative {
|
||||
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic)
|
||||
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false)
|
||||
}
|
||||
return selection, decision, err
|
||||
}
|
||||
@@ -937,6 +1049,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
||||
excludedIDs map[int64]struct{},
|
||||
requiredTransport OpenAIUpstreamTransport,
|
||||
requiredImageCapability OpenAIImagesCapability,
|
||||
requireCompact bool,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
decision := OpenAIAccountScheduleDecision{}
|
||||
scheduler := s.getOpenAIAccountScheduler(ctx)
|
||||
@@ -945,7 +1058,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
||||
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
||||
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
||||
for {
|
||||
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
|
||||
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
@@ -970,7 +1083,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
||||
|
||||
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
||||
for {
|
||||
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
|
||||
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
@@ -1008,6 +1121,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
||||
RequestedModel: requestedModel,
|
||||
RequiredTransport: requiredTransport,
|
||||
RequiredImageCapability: requiredImageCapability,
|
||||
RequireCompact: requireCompact,
|
||||
ExcludedIDs: excludedIDs,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,195 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown
|
||||
// 验证 compact 调度时显式支持 (tier=2) 优先于未探测 (tier=1)。
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown(t *testing.T) {
|
||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||
|
||||
ctx := context.Background()
|
||||
groupID := int64(91001)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 71001,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{}, // unknown
|
||||
},
|
||||
{
|
||||
ID: 71002,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{"openai_compact_supported": true}, // tier=2
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &schedulerTestGatewayCache{},
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, _, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
"",
|
||||
"gpt-5.4",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(71002), selection.Account.ID, "compact-supported account should win over unknown")
|
||||
}
|
||||
|
||||
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported
|
||||
// 验证 force_off / 已探测不支持 (tier=0) 的账号不会被 compact 请求选中。
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported(t *testing.T) {
|
||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||
|
||||
ctx := context.Background()
|
||||
groupID := int64(91002)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 71010,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
|
||||
},
|
||||
{
|
||||
ID: 71011,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{"openai_compact_supported": false},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &schedulerTestGatewayCache{},
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, _, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
"",
|
||||
"gpt-5.4",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
true,
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, ErrNoAvailableCompactAccounts), "compact-only accounts should rejected explicitly unsupported and return compact error")
|
||||
require.Nil(t, selection)
|
||||
}
|
||||
|
||||
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown
|
||||
// 验证当没有"已知支持"账号时,compact 请求会回退到"未探测"账号。
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown(t *testing.T) {
|
||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||
|
||||
ctx := context.Background()
|
||||
groupID := int64(91003)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 71020,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{"openai_compact_supported": false}, // tier=0
|
||||
},
|
||||
{
|
||||
ID: 71021,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{}, // unknown -> tier=1
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &schedulerTestGatewayCache{},
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, _, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
"",
|
||||
"gpt-5.4",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(71021), selection.Account.ID, "unknown account should be picked when no supported account available")
|
||||
}
|
||||
|
||||
// TestOpenAICompactSupportTier 验证 tier 分类逻辑。
|
||||
func TestOpenAICompactSupportTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
want int
|
||||
}{
|
||||
{name: "nil", account: nil, want: 0},
|
||||
{name: "non openai", account: &Account{Platform: PlatformAnthropic}, want: 0},
|
||||
{name: "openai unknown", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{}}, want: 1},
|
||||
{name: "openai supported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": true}}, want: 2},
|
||||
{name: "openai unsupported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": false}}, want: 0},
|
||||
{name: "force on", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn}}, want: 2},
|
||||
{name: "force off overrides probe true", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff, "openai_compact_supported": true}}, want: 0},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := openAICompactSupportTier(tt.account); got != tt.want {
|
||||
t.Fatalf("openAICompactSupportTier(...) = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -289,6 +289,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -343,6 +344,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -384,6 +386,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
require.ErrorContains(t, err, "no available OpenAI accounts")
|
||||
require.Nil(t, selection)
|
||||
@@ -445,6 +448,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -486,7 +490,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
@@ -540,7 +544,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
@@ -616,6 +620,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -662,6 +667,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -740,6 +746,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -788,6 +795,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -857,6 +865,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -900,6 +909,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, selection)
|
||||
@@ -976,6 +986,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
@@ -1014,7 +1025,7 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
|
||||
@@ -1218,6 +1229,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
|
||||
@@ -54,6 +54,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
@@ -48,6 +49,8 @@ type codexTransformResult struct {
|
||||
const (
|
||||
codexImageGenerationBridgeMarker = "<sub2api-codex-image-generation>"
|
||||
codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n</sub2api-codex-image-generation>"
|
||||
codexSparkImageUnsupportedMarker = "<sub2api-codex-spark-image-unsupported>"
|
||||
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
|
||||
)
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
|
||||
@@ -151,6 +154,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
if normalizeCodexTools(reqBody) {
|
||||
result.Modified = true
|
||||
}
|
||||
if normalizeCodexToolChoice(reqBody) {
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
result.PromptCacheKey = strings.TrimSpace(v)
|
||||
@@ -165,9 +171,20 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
if applyInstructions(reqBody, isCodexCLI) {
|
||||
result.Modified = true
|
||||
}
|
||||
if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) {
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||
if input, ok := reqBody["input"].([]any); ok {
|
||||
if normalizedInput, modified := normalizeCodexToolRoleMessages(input); modified {
|
||||
input = normalizedInput
|
||||
result.Modified = true
|
||||
}
|
||||
if normalizedInput, modified := normalizeCodexMessageContentText(input); modified {
|
||||
input = normalizedInput
|
||||
result.Modified = true
|
||||
}
|
||||
input = filterCodexInput(input, needsToolContinuation)
|
||||
reqBody["input"] = input
|
||||
result.Modified = true
|
||||
@@ -192,6 +209,183 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeCodexToolChoice(reqBody map[string]any) bool {
|
||||
choice, ok := reqBody["tool_choice"]
|
||||
if !ok || choice == nil {
|
||||
return false
|
||||
}
|
||||
choiceMap, ok := choice.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
|
||||
if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
|
||||
return false
|
||||
}
|
||||
reqBody["tool_choice"] = "auto"
|
||||
return true
|
||||
}
|
||||
|
||||
func codexToolsContainType(rawTools any, toolType string) bool {
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok || strings.TrimSpace(toolType) == "" {
|
||||
return false
|
||||
}
|
||||
for _, rawTool := range tools {
|
||||
tool, ok := rawTool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(firstNonEmptyString(tool["type"])) == toolType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
|
||||
if len(input) == 0 {
|
||||
return input, false
|
||||
}
|
||||
|
||||
modified := false
|
||||
normalized := make([]any, 0, len(input))
|
||||
for _, item := range input {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
normalized = append(normalized, item)
|
||||
continue
|
||||
}
|
||||
role, _ := m["role"].(string)
|
||||
if strings.TrimSpace(role) != "tool" {
|
||||
normalized = append(normalized, item)
|
||||
continue
|
||||
}
|
||||
|
||||
callID := firstNonEmptyString(m["call_id"], m["tool_call_id"], m["id"])
|
||||
callID = strings.TrimSpace(callID)
|
||||
if callID == "" {
|
||||
// Responses does not accept role:"tool". If no call id is available,
|
||||
// preserve the text as a user message instead of sending invalid input.
|
||||
fallback := make(map[string]any, len(m))
|
||||
for key, value := range m {
|
||||
fallback[key] = value
|
||||
}
|
||||
fallback["role"] = "user"
|
||||
delete(fallback, "tool_call_id")
|
||||
normalized = append(normalized, fallback)
|
||||
modified = true
|
||||
continue
|
||||
}
|
||||
|
||||
output := extractTextFromContent(m["content"])
|
||||
if output == "" {
|
||||
if value, ok := m["output"].(string); ok {
|
||||
output = value
|
||||
}
|
||||
}
|
||||
if output == "" && m["content"] != nil {
|
||||
if b, err := json.Marshal(m["content"]); err == nil {
|
||||
output = string(b)
|
||||
}
|
||||
}
|
||||
|
||||
normalized = append(normalized, map[string]any{
|
||||
"type": "function_call_output",
|
||||
"call_id": callID,
|
||||
"output": output,
|
||||
})
|
||||
modified = true
|
||||
}
|
||||
if !modified {
|
||||
return input, false
|
||||
}
|
||||
return normalized, true
|
||||
}
|
||||
|
||||
func normalizeCodexMessageContentText(input []any) ([]any, bool) {
|
||||
if len(input) == 0 {
|
||||
return input, false
|
||||
}
|
||||
|
||||
modified := false
|
||||
normalized := make([]any, 0, len(input))
|
||||
for _, item := range input {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok || strings.TrimSpace(firstNonEmptyString(m["type"])) != "message" {
|
||||
normalized = append(normalized, item)
|
||||
continue
|
||||
}
|
||||
parts, ok := m["content"].([]any)
|
||||
if !ok {
|
||||
normalized = append(normalized, item)
|
||||
continue
|
||||
}
|
||||
|
||||
var newItem map[string]any
|
||||
var newParts []any
|
||||
ensureItemCopy := func() {
|
||||
if newItem != nil {
|
||||
return
|
||||
}
|
||||
newItem = make(map[string]any, len(m))
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
newParts = make([]any, len(parts))
|
||||
copy(newParts, parts)
|
||||
}
|
||||
|
||||
for i, rawPart := range parts {
|
||||
part, ok := rawPart.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
text, hasText := part["text"]
|
||||
if !hasText {
|
||||
continue
|
||||
}
|
||||
if _, ok := text.(string); ok {
|
||||
continue
|
||||
}
|
||||
|
||||
ensureItemCopy()
|
||||
newPart := make(map[string]any, len(part))
|
||||
for key, value := range part {
|
||||
newPart[key] = value
|
||||
}
|
||||
newPart["text"] = stringifyCodexContentText(text)
|
||||
newParts[i] = newPart
|
||||
modified = true
|
||||
}
|
||||
|
||||
if newItem != nil {
|
||||
newItem["content"] = newParts
|
||||
normalized = append(normalized, newItem)
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, item)
|
||||
}
|
||||
if !modified {
|
||||
return input, false
|
||||
}
|
||||
return normalized, true
|
||||
}
|
||||
|
||||
func stringifyCodexContentText(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case nil:
|
||||
return ""
|
||||
default:
|
||||
if b, err := json.Marshal(v); err == nil {
|
||||
return string(b)
|
||||
}
|
||||
return fmt.Sprint(v)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeCodexModel(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
@@ -244,6 +438,10 @@ func normalizeCodexModel(model string) string {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
|
||||
func isCodexSparkModel(model string) bool {
|
||||
return normalizeCodexModel(model) == "gpt-5.3-codex-spark"
|
||||
}
|
||||
|
||||
func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
|
||||
rawTools, ok := reqBody["tools"]
|
||||
if !ok || rawTools == nil {
|
||||
@@ -265,6 +463,40 @@ func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func hasOpenAIInputImage(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
return hasOpenAIInputImageValue(reqBody["input"]) || hasOpenAIInputImageValue(reqBody["messages"])
|
||||
}
|
||||
|
||||
func hasOpenAIInputImageValue(value any) bool {
|
||||
switch v := value.(type) {
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if hasOpenAIInputImageValue(item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
if strings.TrimSpace(firstNonEmptyString(v["type"])) == "input_image" {
|
||||
return true
|
||||
}
|
||||
if _, ok := v["image_url"]; ok {
|
||||
return true
|
||||
}
|
||||
return hasOpenAIInputImageValue(v["content"])
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateCodexSparkInput(reqBody map[string]any, model string) error {
|
||||
if !isCodexSparkModel(model) || !hasOpenAIInputImage(reqBody) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("model %q does not support image input", strings.TrimSpace(model))
|
||||
}
|
||||
|
||||
func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
|
||||
rawTools, ok := reqBody["tools"]
|
||||
if !ok || rawTools == nil {
|
||||
@@ -309,6 +541,9 @@ func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool {
|
||||
if len(reqBody) == 0 {
|
||||
return false
|
||||
}
|
||||
if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
|
||||
return false
|
||||
}
|
||||
|
||||
tool := map[string]any{
|
||||
"type": "image_generation",
|
||||
@@ -344,6 +579,9 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
|
||||
if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) {
|
||||
return false
|
||||
}
|
||||
if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
|
||||
return false
|
||||
}
|
||||
|
||||
existing, _ := reqBody["instructions"].(string)
|
||||
if strings.Contains(existing, codexImageGenerationBridgeMarker) {
|
||||
@@ -360,6 +598,23 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func applyCodexSparkImageUnsupportedInstructions(reqBody map[string]any) bool {
|
||||
if len(reqBody) == 0 {
|
||||
return false
|
||||
}
|
||||
existing, _ := reqBody["instructions"].(string)
|
||||
if strings.Contains(existing, codexSparkImageUnsupportedMarker) {
|
||||
return false
|
||||
}
|
||||
existing = strings.TrimRight(existing, " \t\r\n")
|
||||
if strings.TrimSpace(existing) == "" {
|
||||
reqBody["instructions"] = codexSparkImageUnsupportedText
|
||||
return true
|
||||
}
|
||||
reqBody["instructions"] = existing + "\n\n" + codexSparkImageUnsupportedText
|
||||
return true
|
||||
}
|
||||
|
||||
func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
|
||||
if !hasOpenAIImageGenerationTool(reqBody) {
|
||||
return nil
|
||||
@@ -658,12 +913,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
}
|
||||
|
||||
if !isCodexToolCallItemType(typ) {
|
||||
ensureCopy()
|
||||
delete(newItem, "call_id")
|
||||
}
|
||||
|
||||
if codexInputItemRequiresName(typ) {
|
||||
if strings.TrimSpace(firstNonEmptyString(m["name"])) == "" {
|
||||
name := firstNonEmptyString(m["tool_name"])
|
||||
if name == "" {
|
||||
if function, ok := m["function"].(map[string]any); ok {
|
||||
name = firstNonEmptyString(function["name"])
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
name = "tool"
|
||||
}
|
||||
ensureCopy()
|
||||
newItem["name"] = name
|
||||
}
|
||||
}
|
||||
|
||||
if !preserveReferences {
|
||||
ensureCopy()
|
||||
delete(newItem, "id")
|
||||
if !isCodexToolCallItemType(typ) {
|
||||
delete(newItem, "call_id")
|
||||
}
|
||||
}
|
||||
|
||||
filtered = append(filtered, newItem)
|
||||
@@ -672,10 +945,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
|
||||
func isCodexToolCallItemType(typ string) bool {
|
||||
if typ == "" {
|
||||
switch typ {
|
||||
case "function_call",
|
||||
"tool_call",
|
||||
"local_shell_call",
|
||||
"tool_search_call",
|
||||
"custom_tool_call",
|
||||
"mcp_tool_call",
|
||||
"function_call_output",
|
||||
"mcp_tool_call_output",
|
||||
"custom_tool_call_output",
|
||||
"tool_search_output":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func codexInputItemRequiresName(typ string) bool {
|
||||
switch strings.TrimSpace(typ) {
|
||||
case "function_call", "custom_tool_call", "mcp_tool_call":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output")
|
||||
}
|
||||
|
||||
func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
|
||||
@@ -92,6 +92,235 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly
|
||||
require.Equal(t, "fc1", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "tool_search_output", "call_id": "call_1", "output": "ok"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 1)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "tool_search_output", first["type"])
|
||||
require.Equal(t, "fc1", first["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom", "output": "ok"},
|
||||
map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp", "output": "ok"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fccustom", first["call_id"])
|
||||
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fcmcp", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "image_generation_call", "id": "ig_123", "status": "completed"},
|
||||
map[string]any{"type": "web_search_call", "call_id": "call_bad", "status": "completed"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "ig_123", first["id"])
|
||||
_, hasCallID := first["call_id"]
|
||||
require.False(t, hasCallID)
|
||||
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasCallID = second["call_id"]
|
||||
require.False(t, hasCallID)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"type": "message",
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "ok",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 1)
|
||||
|
||||
item, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "function_call_output", item["type"])
|
||||
require.Equal(t, "fc1", item["call_id"])
|
||||
require.Equal(t, "ok", item["output"])
|
||||
_, hasRole := item["role"]
|
||||
require.False(t, hasRole)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "input_text", "text": []any{"a", "b"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
item, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := item["content"].([]any)
|
||||
require.True(t, ok)
|
||||
part, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, `["a","b"]`, part["text"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"tools": []any{
|
||||
map[string]any{"type": "function", "name": "shell"},
|
||||
},
|
||||
"tool_choice": map[string]any{"type": "custom"},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
require.Equal(t, "auto", reqBody["tool_choice"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"tools": []any{
|
||||
map[string]any{"type": "custom", "name": "shell"},
|
||||
},
|
||||
"tool_choice": map[string]any{"type": "custom"},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
choice, ok := reqBody["tool_choice"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "custom", choice["type"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"input": []any{
|
||||
map[string]any{"type": "message", "role": "user", "content": "run tool"},
|
||||
map[string]any{"type": "function_call", "call_id": "call_1", "arguments": "{}"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
item, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "function_call", item["type"])
|
||||
require.Equal(t, "tool", item["name"])
|
||||
require.Equal(t, "fc1", item["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"input": []any{
|
||||
map[string]any{"type": "custom_tool_call", "call_id": "call_1", "name": "shell", "input": "pwd"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 1)
|
||||
item, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "shell", item["name"])
|
||||
require.Equal(t, "fc1", item["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"type": "mcp_tool_call",
|
||||
"call_id": "call_abc",
|
||||
"name": "remote_tool",
|
||||
"arguments": "{}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 1)
|
||||
item, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "mcp_tool_call", item["type"])
|
||||
require.Equal(t, "remote_tool", item["name"])
|
||||
require.Equal(t, "fcabc", item["call_id"])
|
||||
}
|
||||
|
||||
func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) {
|
||||
for _, typ := range []string{"function_call", "custom_tool_call", "mcp_tool_call"} {
|
||||
require.True(t, codexInputItemRequiresName(typ), typ)
|
||||
require.True(t, isCodexToolCallItemType(typ), typ)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
// 续链场景:显式 store=false 不再强制为 true,保持 false。
|
||||
|
||||
@@ -261,6 +490,17 @@ func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) {
|
||||
require.Equal(t, "png", tool["output_format"])
|
||||
}
|
||||
|
||||
func TestEnsureOpenAIResponsesImageGenerationTool_SkipsSpark(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.3-codex-spark",
|
||||
"input": "draw a cat",
|
||||
}
|
||||
|
||||
modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
|
||||
require.False(t, modified)
|
||||
require.NotContains(t, reqBody, "tools")
|
||||
}
|
||||
|
||||
func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
@@ -306,6 +546,7 @@ func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t *
|
||||
|
||||
func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"instructions": "existing instructions",
|
||||
"tools": []any{
|
||||
map[string]any{"type": "image_generation", "output_format": "png"},
|
||||
@@ -325,6 +566,20 @@ func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testin
|
||||
require.False(t, modified)
|
||||
}
|
||||
|
||||
func TestApplyCodexImageGenerationBridgeInstructions_SkipsSpark(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.3-codex-spark",
|
||||
"instructions": "existing instructions",
|
||||
"tools": []any{
|
||||
map[string]any{"type": "image_generation", "output_format": "png"},
|
||||
},
|
||||
}
|
||||
|
||||
modified := applyCodexImageGenerationBridgeInstructions(reqBody)
|
||||
require.False(t, modified)
|
||||
require.Equal(t, "existing instructions", reqBody["instructions"])
|
||||
}
|
||||
|
||||
func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"instructions": "existing instructions",
|
||||
@@ -338,6 +593,91 @@ func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *te
|
||||
require.Equal(t, "existing instructions", reqBody["instructions"])
|
||||
}
|
||||
|
||||
func TestValidateCodexSparkInputRejectsInputImage(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.3-codex-spark",
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "input_text", "text": "describe"},
|
||||
map[string]any{"type": "input_image", "image_url": "data:image/png;base64,aGVsbG8="},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "does not support image input")
|
||||
}
|
||||
|
||||
func TestValidateCodexSparkInputRejectsChatImageURL(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.3-codex-spark",
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "describe"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,aGVsbG8="}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateCodexSparkInputAllowsTextOnly(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.3-codex-spark",
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "input_text", "text": "hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark"))
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.3-codex-spark",
|
||||
"instructions": "existing instructions",
|
||||
"input": "hello",
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true, false)
|
||||
require.True(t, result.Modified)
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, instructions, "existing instructions")
|
||||
require.Contains(t, instructions, codexSparkImageUnsupportedMarker)
|
||||
require.Contains(t, instructions, "does not support image generation")
|
||||
require.Contains(t, instructions, "switch to a non-Spark Codex model")
|
||||
require.NotContains(t, instructions, codexImageGenerationBridgeMarker)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"instructions": "existing instructions",
|
||||
"input": "hello",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotContains(t, instructions, codexSparkImageUnsupportedMarker)
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-image-2",
|
||||
|
||||
135
backend/internal/service/openai_compact_model_mapping_test.go
Normal file
135
backend/internal/service/openai_compact_model_mapping_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayService_Forward_CompactOnlyModelMappingOverridesOAuthUpstreamModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"compact-test","input":"hello"}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-map"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_123","status":"completed","model":"gpt-5.4-openai-compact","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
|
||||
},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "gpt-5.4", result.Model)
|
||||
require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_NonCompactRequestIgnoresCompactOnlyModelMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"normal-test","input":"hello"}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-normal-map"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_124","status":"completed","model":"gpt-5.4","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
|
||||
},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "gpt-5.4", result.Model)
|
||||
require.Equal(t, "gpt-5.4", result.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_CompactOnlyModelMappingOverridesUpstreamModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.4","stream":true,"store":true,"instructions":"compact-pass","input":[{"type":"text","text":"compact me"}]}`)
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-pass-map"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_124","model":"gpt-5.4-openai-compact","usage":{"input_tokens":2,"output_tokens":3}}`)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Name: "openai-oauth-pass",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
|
||||
},
|
||||
Extra: map[string]any{"openai_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "gpt-5.4", result.Model)
|
||||
require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(rec.Body.Bytes(), "model").String())
|
||||
}
|
||||
120
backend/internal/service/openai_compact_probe.go
Normal file
120
backend/internal/service/openai_compact_probe.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// AccountTestModeDefault drives the standard /responses connection test.
|
||||
AccountTestModeDefault = "default"
|
||||
// AccountTestModeCompact drives the /responses/compact compact-probe test.
|
||||
AccountTestModeCompact = "compact"
|
||||
)
|
||||
|
||||
func normalizeAccountTestMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case AccountTestModeCompact:
|
||||
return AccountTestModeCompact
|
||||
default:
|
||||
return AccountTestModeDefault
|
||||
}
|
||||
}
|
||||
|
||||
func createOpenAICompactProbePayload(model string) map[string]any {
|
||||
return map[string]any{
|
||||
"model": strings.TrimSpace(model),
|
||||
"instructions": "You are a helpful coding assistant.",
|
||||
"input": []any{
|
||||
map[string]any{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": "Respond with OK.",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func shouldMarkOpenAICompactUnsupported(status int, body []byte) bool {
|
||||
switch status {
|
||||
case http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented:
|
||||
return true
|
||||
case http.StatusBadRequest, http.StatusForbidden, http.StatusUnprocessableEntity:
|
||||
lower := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body) + " " + string(body)))
|
||||
if strings.Contains(lower, "compact") {
|
||||
for _, keyword := range []string{
|
||||
"unsupported",
|
||||
"not support",
|
||||
"does not support",
|
||||
"not available",
|
||||
"disabled",
|
||||
} {
|
||||
if strings.Contains(lower, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildOpenAICompactProbeExtraUpdates(resp *http.Response, body []byte, probeErr error, now time.Time) map[string]any {
|
||||
updates := map[string]any{
|
||||
"openai_compact_checked_at": now.Format(time.RFC3339),
|
||||
"openai_compact_last_status": nil,
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
updates["openai_compact_last_status"] = resp.StatusCode
|
||||
}
|
||||
|
||||
switch {
|
||||
case probeErr != nil:
|
||||
updates["openai_compact_last_error"] = truncateString(sanitizeUpstreamErrorMessage(probeErr.Error()), 2048)
|
||||
case resp == nil:
|
||||
updates["openai_compact_last_error"] = "compact probe failed"
|
||||
default:
|
||||
errMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
if errMsg == "" && len(body) > 0 {
|
||||
errMsg = strings.TrimSpace(string(body))
|
||||
}
|
||||
if errMsg == "" && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
|
||||
errMsg = "HTTP " + strconv.Itoa(resp.StatusCode)
|
||||
}
|
||||
errMsg = truncateString(sanitizeUpstreamErrorMessage(errMsg), 2048)
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
updates["openai_compact_supported"] = true
|
||||
updates["openai_compact_last_error"] = ""
|
||||
} else {
|
||||
if shouldMarkOpenAICompactUnsupported(resp.StatusCode, body) {
|
||||
updates["openai_compact_supported"] = false
|
||||
}
|
||||
updates["openai_compact_last_error"] = errMsg
|
||||
}
|
||||
}
|
||||
|
||||
return updates
|
||||
}
|
||||
|
||||
func mergeExtraUpdates(base map[string]any, more map[string]any) map[string]any {
|
||||
if len(base) == 0 && len(more) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(base)+len(more))
|
||||
for key, value := range base {
|
||||
out[key] = value
|
||||
}
|
||||
for key, value := range more {
|
||||
out[key] = value
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func compactProbeSessionID(accountID int64) string {
|
||||
if accountID <= 0 {
|
||||
return "probe_compact"
|
||||
}
|
||||
return "probe_compact_" + strconv.FormatInt(accountID, 10)
|
||||
}
|
||||
122
backend/internal/service/openai_compact_probe_test.go
Normal file
122
backend/internal/service/openai_compact_probe_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNormalizeAccountTestMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{input: "", want: AccountTestModeDefault},
|
||||
{input: "default", want: AccountTestModeDefault},
|
||||
{input: " compact ", want: AccountTestModeCompact},
|
||||
{input: "COMPACT", want: AccountTestModeCompact},
|
||||
{input: "unknown", want: AccountTestModeDefault},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := normalizeAccountTestMode(tt.input); got != tt.want {
|
||||
t.Fatalf("normalizeAccountTestMode(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompactProbeExtraUpdates_SuccessMarksSupported(t *testing.T) {
|
||||
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
|
||||
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusOK}, []byte(`{"id":"cmp_1"}`), nil, now)
|
||||
|
||||
if got := updates["openai_compact_supported"]; got != true {
|
||||
t.Fatalf("openai_compact_supported = %v, want true", got)
|
||||
}
|
||||
if got := updates["openai_compact_last_status"]; got != http.StatusOK {
|
||||
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusOK)
|
||||
}
|
||||
if got := updates["openai_compact_last_error"]; got != "" {
|
||||
t.Fatalf("openai_compact_last_error = %v, want empty string", got)
|
||||
}
|
||||
if got := updates["openai_compact_checked_at"]; got != now.Format(time.RFC3339) {
|
||||
t.Fatalf("openai_compact_checked_at = %v, want %s", got, now.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompactProbeExtraUpdates_404MarksUnsupported(t *testing.T) {
|
||||
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
|
||||
body := []byte(`404 page not found`)
|
||||
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusNotFound}, body, nil, now)
|
||||
|
||||
if got := updates["openai_compact_supported"]; got != false {
|
||||
t.Fatalf("openai_compact_supported = %v, want false", got)
|
||||
}
|
||||
if got := updates["openai_compact_last_status"]; got != http.StatusNotFound {
|
||||
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompactProbeExtraUpdates_502DoesNotMarkUnsupported(t *testing.T) {
|
||||
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
|
||||
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadGateway}, []byte(`Upstream request failed`), nil, now)
|
||||
|
||||
if _, exists := updates["openai_compact_supported"]; exists {
|
||||
t.Fatalf("did not expect openai_compact_supported for 502 response")
|
||||
}
|
||||
if got := updates["openai_compact_last_status"]; got != http.StatusBadGateway {
|
||||
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadGateway)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompactProbeExtraUpdates_RequestErrorDoesNotMarkUnsupported(t *testing.T) {
|
||||
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
|
||||
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, errors.New("dial tcp timeout"), now)
|
||||
|
||||
if _, exists := updates["openai_compact_supported"]; exists {
|
||||
t.Fatalf("did not expect openai_compact_supported for request error")
|
||||
}
|
||||
if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
|
||||
t.Fatalf("openai_compact_last_status = %v, want nil key", got)
|
||||
}
|
||||
if got := updates["openai_compact_last_error"]; got == "" {
|
||||
t.Fatalf("expected openai_compact_last_error to be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompactProbeExtraUpdates_NoResponseClearsLastStatus(t *testing.T) {
|
||||
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
|
||||
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, nil, now)
|
||||
|
||||
if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
|
||||
t.Fatalf("openai_compact_last_status = %v, want nil key", got)
|
||||
}
|
||||
if got := updates["openai_compact_last_error"]; got != "compact probe failed" {
|
||||
t.Fatalf("openai_compact_last_error = %v, want compact probe failed", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompactProbeExtraUpdates_UnknownModelDoesNotMarkUnsupported(t *testing.T) {
|
||||
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
|
||||
body := []byte(`{"error":{"message":"unknown model gpt-5.4-openai-compact"}}`)
|
||||
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadRequest}, body, nil, now)
|
||||
|
||||
if _, exists := updates["openai_compact_supported"]; exists {
|
||||
t.Fatalf("did not expect openai_compact_supported for unknown-model diagnostics")
|
||||
}
|
||||
if got := updates["openai_compact_last_status"]; got != http.StatusBadRequest {
|
||||
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompactProbeExtraUpdates_EmptyFailureBodyFallsBackToHTTPStatus(t *testing.T) {
|
||||
now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
|
||||
updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusServiceUnavailable}, nil, nil, now)
|
||||
|
||||
if got := updates["openai_compact_last_status"]; got != http.StatusServiceUnavailable {
|
||||
t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusServiceUnavailable)
|
||||
}
|
||||
if got := updates["openai_compact_last_error"]; got != "HTTP 503" {
|
||||
t.Fatalf("openai_compact_last_error = %v, want HTTP 503", got)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
|
||||
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
||||
func (c cancelReadCloser) Close() error { return nil }
|
||||
|
||||
type errReadCloser struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }
|
||||
func (r errReadCloser) Close() error { return nil }
|
||||
|
||||
type failingGinWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
@@ -1003,6 +1010,190 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: errReadCloser{err: io.ErrUnexpectedEOF},
|
||||
Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||
require.Error(t, err)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||
require.False(t, c.Writer.Written())
|
||||
require.Empty(t, rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
"event: response.created",
|
||||
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
|
||||
"",
|
||||
"event: response.in_progress",
|
||||
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
|
||||
"",
|
||||
"event: response.failed",
|
||||
`data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
|
||||
"",
|
||||
}, "\n"))),
|
||||
Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||
require.Error(t, err)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
|
||||
require.False(t, c.Writer.Written())
|
||||
require.Empty(t, rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
"event: response.created",
|
||||
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
|
||||
"",
|
||||
"event: response.in_progress",
|
||||
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
|
||||
"",
|
||||
}, "\n"))),
|
||||
Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||
require.Error(t, err)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.False(t, c.Writer.Written())
|
||||
require.Empty(t, rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 1,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
|
||||
for i := 0; i < 6; i++ {
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
|
||||
}
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, rec.Body.String(), ":\n\n")
|
||||
require.Contains(t, rec.Body.String(), "response.completed")
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
"event: response.created",
|
||||
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
|
||||
"",
|
||||
"event: response.failed",
|
||||
`data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
|
||||
"",
|
||||
}, "\n"))),
|
||||
Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||
require.Error(t, err)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.False(t, errors.As(err, &failoverErr))
|
||||
require.True(t, c.Writer.Written())
|
||||
require.Contains(t, rec.Body.String(), "response.failed")
|
||||
require.Contains(t, rec.Body.String(), "high-risk cyber activity")
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1072,7 +1263,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
@@ -1104,16 +1295,52 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
|
||||
_ = pr.Close()
|
||||
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
"event: response.created",
|
||||
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
|
||||
"",
|
||||
"event: response.failed",
|
||||
`data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
|
||||
"",
|
||||
}, "\n"))),
|
||||
Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
|
||||
require.Error(t, err)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
|
||||
require.False(t, c.Writer.Written())
|
||||
require.Empty(t, rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1139,7 +1366,42 @@ func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 2, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.Equal(t, 1, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughResponseIncompleteWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.incomplete\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package service
|
||||
|
||||
import "strings"
|
||||
|
||||
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
||||
// forwarding. Group-level default mapping only applies when the account itself
|
||||
// did not match any explicit model_mapping rule.
|
||||
@@ -12,8 +14,47 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
|
||||
}
|
||||
|
||||
mappedModel, matched := account.ResolveMappedModel(requestedModel)
|
||||
if !matched && defaultMappedModel != "" {
|
||||
if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
func isExplicitCodexModel(model string) bool {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
parts := strings.Split(model, "/")
|
||||
model = parts[len(parts)-1]
|
||||
}
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if getNormalizedCodexModel(model) != "" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(model, "-openai-compact") {
|
||||
base := strings.TrimSuffix(model, "-openai-compact")
|
||||
return getNormalizedCodexModel(base) != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveOpenAICompactForwardModel determines the compact-only upstream model
|
||||
// for /responses/compact requests. It never affects normal /responses traffic.
|
||||
// When no compact-specific mapping matches, the input model is returned as-is.
|
||||
func resolveOpenAICompactForwardModel(account *Account, model string) string {
|
||||
trimmedModel := strings.TrimSpace(model)
|
||||
if trimmedModel == "" || account == nil {
|
||||
return trimmedModel
|
||||
}
|
||||
|
||||
mappedModel, matched := account.ResolveCompactMappedModel(trimmedModel)
|
||||
if !matched {
|
||||
return trimmedModel
|
||||
}
|
||||
if trimmedMapped := strings.TrimSpace(mappedModel); trimmedMapped != "" {
|
||||
return trimmedMapped
|
||||
}
|
||||
return trimmedModel
|
||||
}
|
||||
|
||||
@@ -15,10 +15,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
requestedModel: "claude-opus-4-6",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
name: "preserves explicit gpt-5.4 instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "preserves exact passthrough mapping instead of group default",
|
||||
account: &Account{
|
||||
@@ -58,6 +67,42 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "preserves codex spark instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt-5.3-codex-spark",
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.3-codex-spark",
|
||||
},
|
||||
{
|
||||
name: "preserves gpt-5.5 instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt-5.5",
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.5",
|
||||
},
|
||||
{
|
||||
name: "preserves openai namespaced gpt-5.5 instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "openai/gpt-5.5",
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "openai/gpt-5.5",
|
||||
},
|
||||
{
|
||||
name: "preserves compact gpt-5.5 instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt-5.5-openai-compact",
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.5-openai-compact",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -85,6 +130,74 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAICompactForwardModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
model string
|
||||
expectedModel string
|
||||
}{
|
||||
{
|
||||
name: "nil account keeps original model",
|
||||
account: nil,
|
||||
model: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "missing compact mapping keeps original model",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
model: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "exact compact mapping overrides model",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4-openai-compact",
|
||||
},
|
||||
},
|
||||
},
|
||||
model: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4-openai-compact",
|
||||
},
|
||||
{
|
||||
name: "wildcard compact mapping overrides model",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]any{
|
||||
"gpt-5.*": "gpt-5-openai-compact",
|
||||
},
|
||||
},
|
||||
},
|
||||
model: "gpt-5.4",
|
||||
expectedModel: "gpt-5-openai-compact",
|
||||
},
|
||||
{
|
||||
name: "passthrough compact mapping remains unchanged",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"compact_model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
model: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveOpenAICompactForwardModel(tt.account, tt.model); got != tt.expectedModel {
|
||||
t.Fatalf("resolveOpenAICompactForwardModel(...) = %q, want %q", got, tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCodexModel(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
|
||||
@@ -734,7 +734,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
||||
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||
require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
|
||||
require.Equal(t, "codex_cli_rs/0.125.0", upstream.lastReq.Header.Get("User-Agent"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {
|
||||
|
||||
@@ -21,7 +21,7 @@ type FunctionCallOutputValidation struct {
|
||||
}
|
||||
|
||||
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||
// 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、
|
||||
// 或显式声明 tools/tool_choice。
|
||||
func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
@@ -46,7 +46,7 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == "function_call_output" || itemType == "item_reference" {
|
||||
if isCodexToolCallItemType(itemType) || itemType == "item_reference" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,9 @@ func TestNeedsToolContinuationSignals(t *testing.T) {
|
||||
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
|
||||
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
|
||||
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
|
||||
{name: "tool_search_output", body: map[string]any{"input": []any{map[string]any{"type": "tool_search_output"}}}, want: true},
|
||||
{name: "custom_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "custom_tool_call_output"}}}, want: true},
|
||||
{name: "mcp_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "mcp_tool_call_output"}}}, want: true},
|
||||
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
|
||||
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
|
||||
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
|
||||
|
||||
@@ -37,7 +37,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
@@ -77,7 +77,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil)
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil, false)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连")
|
||||
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl")
|
||||
@@ -129,7 +129,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheck
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil, false)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
|
||||
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
|
||||
@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *test
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}, false)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection)
|
||||
}
|
||||
@@ -197,7 +197,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil, false)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
|
||||
}
|
||||
@@ -258,7 +258,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
|
||||
@@ -3800,6 +3800,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
||||
previousResponseID string,
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
requireCompact bool,
|
||||
) (*AccountSelectionResult, error) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
@@ -3840,11 +3841,16 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||
return nil, nil
|
||||
}
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
||||
if account == nil {
|
||||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||
return nil, nil
|
||||
}
|
||||
// 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。
|
||||
if requireCompact && openAICompactSupportTier(account) == 0 {
|
||||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if acquireErr == nil && result.Acquired {
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -268,6 +269,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
|
||||
|
||||
switch action {
|
||||
case redeemActionSkipCompleted:
|
||||
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
|
||||
return err
|
||||
}
|
||||
// Code already created and redeemed — just mark completed
|
||||
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
|
||||
case redeemActionCreate:
|
||||
@@ -281,6 +285,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
|
||||
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
|
||||
return fmt.Errorf("redeem balance: %w", err)
|
||||
}
|
||||
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
|
||||
}
|
||||
|
||||
@@ -358,6 +365,142 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action
|
||||
return c > 0
|
||||
}
|
||||
|
||||
func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||
if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 {
|
||||
return nil
|
||||
}
|
||||
if s.affiliateService == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": fmt.Sprintf("begin affiliate rebate tx: %v", err),
|
||||
})
|
||||
return fmt.Errorf("begin affiliate rebate tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount)
|
||||
if err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return fmt.Errorf("claim affiliate rebate audit: %w", err)
|
||||
}
|
||||
if !claimed {
|
||||
return nil
|
||||
}
|
||||
|
||||
rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
|
||||
if err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return fmt.Errorf("accrue affiliate rebate: %w", err)
|
||||
}
|
||||
|
||||
if rebateAmount <= 0 {
|
||||
if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_SKIPPED", map[string]any{
|
||||
"baseAmount": o.Amount,
|
||||
"reason": "no inviter bound or rebate amount <= 0",
|
||||
}); err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return fmt.Errorf("update affiliate rebate skipped audit: %w", err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
|
||||
})
|
||||
return fmt.Errorf("commit affiliate rebate tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{
|
||||
"baseAmount": o.Amount,
|
||||
"rebateAmount": rebateAmount,
|
||||
}); err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return fmt.Errorf("update affiliate rebate applied audit: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
|
||||
})
|
||||
return fmt.Errorf("commit affiliate rebate tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) {
|
||||
if client == nil {
|
||||
return false, errors.New("nil payment client")
|
||||
}
|
||||
oid := strconv.FormatInt(orderID, 10)
|
||||
detail, _ := json.Marshal(map[string]any{
|
||||
"baseAmount": baseAmount,
|
||||
"status": "reserved",
|
||||
})
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
|
||||
SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', NOW()
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM payment_audit_logs
|
||||
WHERE order_id = $1::text
|
||||
AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
|
||||
)
|
||||
ON CONFLICT (order_id, action) DO NOTHING
|
||||
RETURNING id`, oid, string(detail))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
var claimID int64
|
||||
if err := rows.Scan(&claimID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error {
|
||||
if client == nil {
|
||||
return errors.New("nil payment client")
|
||||
}
|
||||
oid := strconv.FormatInt(orderID, 10)
|
||||
detailJSON, _ := json.Marshal(detail)
|
||||
updated, err := client.PaymentAuditLog.Update().
|
||||
Where(
|
||||
paymentauditlog.OrderIDEQ(oid),
|
||||
paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED"),
|
||||
).
|
||||
SetAction(action).
|
||||
SetDetail(string(detailJSON)).
|
||||
SetOperator("system").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if updated == 0 {
|
||||
return errors.New("affiliate rebate claim log not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
|
||||
now := time.Now()
|
||||
r := psErrMsg(cause)
|
||||
|
||||
@@ -170,21 +170,22 @@ type TopUserStat struct {
|
||||
// --- Service ---
|
||||
|
||||
type PaymentService struct {
|
||||
providerMu sync.Mutex
|
||||
providersLoaded bool
|
||||
entClient *dbent.Client
|
||||
registry *payment.Registry
|
||||
loadBalancer payment.LoadBalancer
|
||||
redeemService *RedeemService
|
||||
subscriptionSvc *SubscriptionService
|
||||
configService *PaymentConfigService
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
resumeService *PaymentResumeService
|
||||
providerMu sync.Mutex
|
||||
providersLoaded bool
|
||||
entClient *dbent.Client
|
||||
registry *payment.Registry
|
||||
loadBalancer payment.LoadBalancer
|
||||
redeemService *RedeemService
|
||||
subscriptionSvc *SubscriptionService
|
||||
configService *PaymentConfigService
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
resumeService *PaymentResumeService
|
||||
affiliateService *AffiliateService
|
||||
}
|
||||
|
||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
|
||||
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
|
||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
|
||||
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService}
|
||||
svc.resumeService = psNewPaymentResumeService(configService)
|
||||
return svc
|
||||
}
|
||||
|
||||
@@ -931,7 +931,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
|
||||
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
|
||||
// 返回 nil 表示无法从响应头中确定重置时间
|
||||
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
|
||||
func calculateOpenAI429ResetTime(headers http.Header) *time.Time {
|
||||
snapshot := ParseCodexRateLimitHeaders(headers)
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
@@ -977,6 +977,10 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
|
||||
return calculateOpenAI429ResetTime(headers)
|
||||
}
|
||||
|
||||
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
|
||||
type anthropic429Result struct {
|
||||
resetAt time.Time // The correct reset time to use for SetRateLimited
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -453,6 +454,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyChannelMonitorEnabled,
|
||||
SettingKeyChannelMonitorDefaultIntervalSeconds,
|
||||
SettingKeyAvailableChannelsEnabled,
|
||||
SettingKeyAffiliateEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -540,6 +542,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]),
|
||||
|
||||
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
|
||||
|
||||
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -686,6 +690,7 @@ type PublicSettingsInjectionPayload struct {
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
}
|
||||
|
||||
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
|
||||
@@ -738,6 +743,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1167,6 +1173,26 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate)
|
||||
updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64)
|
||||
if settings.AffiliateRebateFreezeHours < 0 {
|
||||
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax {
|
||||
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours)
|
||||
if settings.AffiliateRebateDurationDays < 0 {
|
||||
settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax {
|
||||
settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax
|
||||
}
|
||||
updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays)
|
||||
if settings.AffiliateRebatePerInviteeCap < 0 {
|
||||
settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64)
|
||||
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
|
||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||
if err != nil {
|
||||
@@ -1202,6 +1228,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
// Available channels feature switch
|
||||
updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled)
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
|
||||
|
||||
// Claude Code version check
|
||||
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
|
||||
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
|
||||
@@ -1477,6 +1506,78 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
|
||||
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
|
||||
if err != nil {
|
||||
return false // 默认关闭
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// GetAffiliateRebateRatePercent 读取并 clamp 全局返利比例。
|
||||
// 解析失败、缺失或越界都回退到 AffiliateRebateRateDefault — 该比例从不抛错,
|
||||
// 调用方只关心一个可用的数值。
|
||||
func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) float64 {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
|
||||
if err != nil {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
|
||||
if err != nil || math.IsNaN(rate) || math.IsInf(rate, 0) {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
return clampAffiliateRebateRate(rate)
|
||||
}
|
||||
|
||||
// GetAffiliateRebateFreezeHours 返回返利冻结期(小时)。
|
||||
// 返回 0 表示不冻结(向后兼容)。
|
||||
func (s *SettingService) GetAffiliateRebateFreezeHours(ctx context.Context) int {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateFreezeHours)
|
||||
if err != nil {
|
||||
return AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
hours, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || hours < 0 {
|
||||
return AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
if hours > AffiliateRebateFreezeHoursMax {
|
||||
return AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
return hours
|
||||
}
|
||||
|
||||
// GetAffiliateRebateDurationDays 返回返利有效期(天)。
|
||||
// 返回 0 表示永久有效。
|
||||
func (s *SettingService) GetAffiliateRebateDurationDays(ctx context.Context) int {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateDurationDays)
|
||||
if err != nil {
|
||||
return AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
days, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || days < 0 {
|
||||
return AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
if days > AffiliateRebateDurationDaysMax {
|
||||
return AffiliateRebateDurationDaysMax
|
||||
}
|
||||
return days
|
||||
}
|
||||
|
||||
// GetAffiliateRebatePerInviteeCap 返回单人返利上限。
|
||||
// 返回 0 表示无上限。
|
||||
func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) float64 {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebatePerInviteeCap)
|
||||
if err != nil {
|
||||
return AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
cap, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
|
||||
if err != nil || cap < 0 || math.IsNaN(cap) || math.IsInf(cap, 0) {
|
||||
return AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
return cap
|
||||
}
|
||||
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证
|
||||
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
@@ -1719,6 +1820,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyOIDCConnectUserInfoUsernamePath: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
|
||||
SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
|
||||
SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
|
||||
SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
|
||||
SettingKeyDefaultUserRPMLimit: "0",
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultEmailBalance: "0",
|
||||
@@ -1767,6 +1872,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// Available channels feature (default disabled; opt-in)
|
||||
SettingKeyAvailableChannelsEnabled: "false",
|
||||
|
||||
// Affiliate (邀请返利) feature (default disabled; opt-in)
|
||||
SettingKeyAffiliateEnabled: "false",
|
||||
|
||||
// Claude Code version check (default: empty = disabled)
|
||||
SettingKeyMinClaudeCodeVersion: "",
|
||||
SettingKeyMaxClaudeCodeVersion: "",
|
||||
@@ -1846,6 +1954,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
} else {
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil {
|
||||
result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate)
|
||||
} else {
|
||||
result.AffiliateRebateRate = AffiliateRebateRateDefault
|
||||
}
|
||||
if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 {
|
||||
if freezeHours > AffiliateRebateFreezeHoursMax {
|
||||
freezeHours = AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
result.AffiliateRebateFreezeHours = freezeHours
|
||||
}
|
||||
if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 {
|
||||
if durationDays > AffiliateRebateDurationDaysMax {
|
||||
durationDays = AffiliateRebateDurationDaysMax
|
||||
}
|
||||
result.AffiliateRebateDurationDays = durationDays
|
||||
}
|
||||
if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
|
||||
result.AffiliateRebatePerInviteeCap = perInviteeCap
|
||||
}
|
||||
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
@@ -2082,6 +2210,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
// Available channels feature (default: disabled; strict true)
|
||||
result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true"
|
||||
|
||||
// Affiliate (邀请返利) feature (default: disabled; strict true)
|
||||
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
|
||||
|
||||
// Claude Code version check
|
||||
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
|
||||
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
|
||||
@@ -2130,6 +2261,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
return result
|
||||
}
|
||||
|
||||
func clampAffiliateRebateRate(value float64) float64 {
|
||||
if math.IsNaN(value) || math.IsInf(value, 0) {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
if value < AffiliateRebateRateMin {
|
||||
return AffiliateRebateRateMin
|
||||
}
|
||||
if value > AffiliateRebateRateMax {
|
||||
return AffiliateRebateRateMax
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func isFalseSettingValue(value string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "false", "0", "off", "disabled":
|
||||
|
||||
@@ -104,10 +104,15 @@ type SystemSettings struct {
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
CustomEndpoints string // JSON array of custom endpoints
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
DefaultUserRPMLimit int
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
AffiliateEnabled bool
|
||||
AffiliateRebateRate float64
|
||||
AffiliateRebateFreezeHours int
|
||||
AffiliateRebateDurationDays int
|
||||
AffiliateRebatePerInviteeCap float64
|
||||
DefaultUserRPMLimit int
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -224,6 +229,9 @@ type PublicSettings struct {
|
||||
|
||||
// Available Channels feature (user-facing aggregate view)
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
// Affiliate (邀请返利) feature toggle
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
}
|
||||
|
||||
type WeChatConnectOAuthConfig struct {
|
||||
|
||||
@@ -486,6 +486,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewGroupCapacityService,
|
||||
NewChannelService,
|
||||
NewModelPricingResolver,
|
||||
NewAffiliateService,
|
||||
ProvidePaymentConfigService,
|
||||
NewPaymentService,
|
||||
ProvidePaymentOrderExpiryService,
|
||||
|
||||
20
backend/migrations/130_add_user_affiliates.sql
Normal file
20
backend/migrations/130_add_user_affiliates.sql
Normal file
@@ -0,0 +1,20 @@
|
||||
CREATE TABLE IF NOT EXISTS user_affiliates (
|
||||
user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
|
||||
aff_code VARCHAR(32) NOT NULL UNIQUE,
|
||||
inviter_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
|
||||
aff_count INTEGER NOT NULL DEFAULT 0,
|
||||
aff_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
aff_history_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_affiliates_inviter_id ON user_affiliates(inviter_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_affiliates_aff_quota ON user_affiliates(aff_quota);
|
||||
|
||||
COMMENT ON TABLE user_affiliates IS '用户邀请返利信息';
|
||||
COMMENT ON COLUMN user_affiliates.aff_code IS '用户邀请代码';
|
||||
COMMENT ON COLUMN user_affiliates.inviter_id IS '邀请人用户ID';
|
||||
COMMENT ON COLUMN user_affiliates.aff_count IS '累计邀请人数';
|
||||
COMMENT ON COLUMN user_affiliates.aff_quota IS '当前可提取返利金额';
|
||||
COMMENT ON COLUMN user_affiliates.aff_history_quota IS '累计返利历史金额';
|
||||
58
backend/migrations/131_affiliate_rebate_hardening.sql
Normal file
58
backend/migrations/131_affiliate_rebate_hardening.sql
Normal file
@@ -0,0 +1,58 @@
|
||||
-- 1) Normalize historical affiliate rebate rate values.
|
||||
-- Legacy compatibility treated 0<x<=1 as fractional inputs (e.g. 0.2 => 20%).
|
||||
-- We now use pure percentage semantics, so convert persisted fractional values once.
|
||||
UPDATE settings
|
||||
SET value = to_char((value::numeric * 100), 'FM999999990.########'),
|
||||
updated_at = NOW()
|
||||
WHERE key = 'affiliate_rebate_rate'
|
||||
AND value ~ '^-?[0-9]+(\\.[0-9]+)?$'
|
||||
AND value::numeric > 0
|
||||
AND value::numeric <= 1;
|
||||
|
||||
-- 2) Affiliate ledger for accrual/transfer traceability.
|
||||
CREATE TABLE IF NOT EXISTS user_affiliate_ledger (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
action VARCHAR(32) NOT NULL,
|
||||
amount DECIMAL(20,8) NOT NULL,
|
||||
source_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_user_id ON user_affiliate_ledger(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_action ON user_affiliate_ledger(action);
|
||||
|
||||
COMMENT ON TABLE user_affiliate_ledger IS '邀请返利资金流水(累计/转入)';
|
||||
COMMENT ON COLUMN user_affiliate_ledger.action IS 'accrue|transfer';
|
||||
|
||||
-- 3) Enforce idempotency at DB layer for payment audit actions.
|
||||
WITH ranked AS (
|
||||
SELECT id,
|
||||
ROW_NUMBER() OVER (PARTITION BY order_id, action ORDER BY id) AS rn
|
||||
FROM payment_audit_logs
|
||||
)
|
||||
DELETE FROM payment_audit_logs p
|
||||
USING ranked r
|
||||
WHERE p.id = r.id
|
||||
AND r.rn > 1;
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq
|
||||
ON payment_audit_logs(order_id, action);
|
||||
|
||||
-- 4) Prevent retroactive affiliate rebate issuance for legacy completed balance orders.
|
||||
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
|
||||
SELECT po.id::text,
|
||||
'AFFILIATE_REBATE_SKIPPED',
|
||||
'{"reason":"baseline before affiliate rebate idempotency rollout"}',
|
||||
'system',
|
||||
NOW()
|
||||
FROM payment_orders po
|
||||
WHERE po.order_type = 'balance'
|
||||
AND po.status = 'COMPLETED'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM payment_audit_logs pal
|
||||
WHERE pal.order_id = po.id::text
|
||||
AND pal.action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
|
||||
);
|
||||
16
backend/migrations/132_affiliate_custom_settings.sql
Normal file
16
backend/migrations/132_affiliate_custom_settings.sql
Normal file
@@ -0,0 +1,16 @@
|
||||
-- 邀请返利:用户专属配置增强
|
||||
-- 1) aff_rebate_rate_percent: 用户作为邀请人时的专属返利比例(百分比,NULL 表示沿用全局比例)
|
||||
-- 2) aff_code_custom: 标记当前 aff_code 是否被管理员手动改写过(用于"专属用户"列表筛选)
|
||||
|
||||
ALTER TABLE user_affiliates
|
||||
ADD COLUMN IF NOT EXISTS aff_rebate_rate_percent DECIMAL(5,2);
|
||||
|
||||
ALTER TABLE user_affiliates
|
||||
ADD COLUMN IF NOT EXISTS aff_code_custom BOOLEAN NOT NULL DEFAULT false;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_affiliates_admin_settings
|
||||
ON user_affiliates (updated_at)
|
||||
WHERE aff_code_custom = true OR aff_rebate_rate_percent IS NOT NULL;
|
||||
|
||||
COMMENT ON COLUMN user_affiliates.aff_rebate_rate_percent IS '专属返利比例(百分比 0-100,NULL 表示沿用全局)';
|
||||
COMMENT ON COLUMN user_affiliates.aff_code_custom IS '邀请码是否由管理员改写过(用于专属用户筛选)';
|
||||
17
backend/migrations/133_affiliate_rebate_freeze.sql
Normal file
17
backend/migrations/133_affiliate_rebate_freeze.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
-- 1) Add frozen quota column to user_affiliates for rebate freeze period.
|
||||
ALTER TABLE user_affiliates
|
||||
ADD COLUMN IF NOT EXISTS aff_frozen_quota DECIMAL(20,8) NOT NULL DEFAULT 0;
|
||||
|
||||
COMMENT ON COLUMN user_affiliates.aff_frozen_quota IS 'Rebate quota currently frozen (pending thaw after freeze period)';
|
||||
|
||||
-- 2) Add frozen_until column to user_affiliate_ledger for per-entry freeze tracking.
|
||||
-- NULL = no freeze (or already thawed); non-NULL = frozen until this timestamp.
|
||||
ALTER TABLE user_affiliate_ledger
|
||||
ADD COLUMN IF NOT EXISTS frozen_until TIMESTAMPTZ NULL;
|
||||
|
||||
COMMENT ON COLUMN user_affiliate_ledger.frozen_until IS 'Rebate frozen until this time; NULL means already thawed or never frozen';
|
||||
|
||||
-- 3) Partial index for efficient thaw queries (only rows still frozen).
|
||||
CREATE INDEX IF NOT EXISTS idx_ual_frozen_thaw
|
||||
ON user_affiliate_ledger (user_id, frozen_until)
|
||||
WHERE frozen_until IS NOT NULL;
|
||||
@@ -74,6 +74,26 @@ describe('oauth adoption auth api', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('posts affiliate code when completing linuxdo oauth registration', async () => {
|
||||
const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
|
||||
|
||||
await completeLinuxDoOAuthRegistration(
|
||||
'invite-code',
|
||||
{
|
||||
adoptDisplayName: true,
|
||||
adoptAvatar: false
|
||||
},
|
||||
' AFF123 '
|
||||
)
|
||||
|
||||
expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
|
||||
invitation_code: 'invite-code',
|
||||
aff_code: 'AFF123',
|
||||
adopt_display_name: true,
|
||||
adopt_avatar: false
|
||||
})
|
||||
})
|
||||
|
||||
it('posts oidc invitation completion with adoption decisions', async () => {
|
||||
const { completeOIDCOAuthRegistration } = await import('@/api/auth')
|
||||
|
||||
@@ -134,6 +154,26 @@ describe('oauth adoption auth api', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('posts affiliate code when creating pending wechat oauth account', async () => {
|
||||
const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
|
||||
|
||||
await createPendingWeChatOAuthAccount(
|
||||
'invite-code',
|
||||
{
|
||||
adoptDisplayName: false,
|
||||
adoptAvatar: true
|
||||
},
|
||||
'WXAFF'
|
||||
)
|
||||
|
||||
expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
|
||||
invitation_code: 'invite-code',
|
||||
aff_code: 'WXAFF',
|
||||
adopt_display_name: false,
|
||||
adopt_avatar: true
|
||||
})
|
||||
})
|
||||
|
||||
it('classifies oauth completion results as login or bind', async () => {
|
||||
const { getOAuthCompletionKind } = await import('@/api/auth')
|
||||
|
||||
|
||||
108
frontend/src/api/admin/affiliates.ts
Normal file
108
frontend/src/api/admin/affiliates.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
/**
|
||||
* Admin Affiliate API endpoints
|
||||
* Manage per-user affiliate (邀请返利) configurations:
|
||||
* exclusive invite codes (overrides aff_code) and exclusive rebate rates.
|
||||
*/
|
||||
|
||||
import { apiClient } from '../client'
|
||||
import type { PaginatedResponse } from '@/types'
|
||||
|
||||
export interface AffiliateAdminEntry {
|
||||
user_id: number
|
||||
email: string
|
||||
username: string
|
||||
aff_code: string
|
||||
aff_code_custom: boolean
|
||||
aff_rebate_rate_percent?: number | null
|
||||
aff_count: number
|
||||
}
|
||||
|
||||
export interface ListAffiliateUsersParams {
|
||||
page?: number
|
||||
page_size?: number
|
||||
search?: string
|
||||
}
|
||||
|
||||
export interface UpdateAffiliateUserRequest {
|
||||
aff_code?: string
|
||||
aff_rebate_rate_percent?: number | null
|
||||
/** Set true to explicitly clear the per-user rate (sets it to NULL). */
|
||||
clear_rebate_rate?: boolean
|
||||
}
|
||||
|
||||
export interface BatchSetRateRequest {
|
||||
user_ids: number[]
|
||||
aff_rebate_rate_percent?: number | null
|
||||
/** Set true to clear rates instead of setting. */
|
||||
clear?: boolean
|
||||
}
|
||||
|
||||
export interface SimpleUser {
|
||||
id: number
|
||||
email: string
|
||||
username: string
|
||||
}
|
||||
|
||||
export async function listUsers(
|
||||
params: ListAffiliateUsersParams = {},
|
||||
): Promise<PaginatedResponse<AffiliateAdminEntry>> {
|
||||
const { data } = await apiClient.get<PaginatedResponse<AffiliateAdminEntry>>(
|
||||
'/admin/affiliates/users',
|
||||
{
|
||||
params: {
|
||||
page: params.page ?? 1,
|
||||
page_size: params.page_size ?? 20,
|
||||
search: params.search ?? '',
|
||||
},
|
||||
},
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function lookupUsers(q: string): Promise<SimpleUser[]> {
|
||||
const { data } = await apiClient.get<SimpleUser[]>(
|
||||
'/admin/affiliates/users/lookup',
|
||||
{ params: { q } },
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function updateUserSettings(
|
||||
userId: number,
|
||||
payload: UpdateAffiliateUserRequest,
|
||||
): Promise<{ user_id: number }> {
|
||||
const { data } = await apiClient.put<{ user_id: number }>(
|
||||
`/admin/affiliates/users/${userId}`,
|
||||
payload,
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function clearUserSettings(
|
||||
userId: number,
|
||||
): Promise<{ user_id: number }> {
|
||||
const { data } = await apiClient.delete<{ user_id: number }>(
|
||||
`/admin/affiliates/users/${userId}`,
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function batchSetRate(
|
||||
payload: BatchSetRateRequest,
|
||||
): Promise<{ affected: number }> {
|
||||
const { data } = await apiClient.post<{ affected: number }>(
|
||||
'/admin/affiliates/users/batch-rate',
|
||||
payload,
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export const affiliatesAPI = {
|
||||
listUsers,
|
||||
lookupUsers,
|
||||
updateUserSettings,
|
||||
clearUserSettings,
|
||||
batchSetRate,
|
||||
}
|
||||
|
||||
export default affiliatesAPI
|
||||
@@ -29,6 +29,7 @@ import channelsAPI from './channels'
|
||||
import channelMonitorAPI from './channelMonitor'
|
||||
import channelMonitorTemplateAPI from './channelMonitorTemplate'
|
||||
import adminPaymentAPI from './payment'
|
||||
import affiliatesAPI from './affiliates'
|
||||
|
||||
/**
|
||||
* Unified admin API object for convenient access
|
||||
@@ -59,7 +60,8 @@ export const adminAPI = {
|
||||
channels: channelsAPI,
|
||||
channelMonitor: channelMonitorAPI,
|
||||
channelMonitorTemplate: channelMonitorTemplateAPI,
|
||||
payment: adminPaymentAPI
|
||||
payment: adminPaymentAPI,
|
||||
affiliates: affiliatesAPI
|
||||
}
|
||||
|
||||
export {
|
||||
@@ -88,7 +90,8 @@ export {
|
||||
channelsAPI,
|
||||
channelMonitorAPI,
|
||||
channelMonitorTemplateAPI,
|
||||
adminPaymentAPI
|
||||
adminPaymentAPI,
|
||||
affiliatesAPI
|
||||
}
|
||||
|
||||
export default adminAPI
|
||||
|
||||
@@ -308,6 +308,10 @@ export interface SystemSettings {
|
||||
totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置
|
||||
// Default settings
|
||||
default_balance: number;
|
||||
affiliate_rebate_rate: number;
|
||||
affiliate_rebate_freeze_hours: number;
|
||||
affiliate_rebate_duration_days: number;
|
||||
affiliate_rebate_per_invitee_cap: number;
|
||||
default_concurrency: number;
|
||||
default_user_rpm_limit: number;
|
||||
default_subscriptions: DefaultSubscriptionSetting[];
|
||||
@@ -477,6 +481,9 @@ export interface SystemSettings {
|
||||
|
||||
// Available Channels feature switch
|
||||
available_channels_enabled: boolean;
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
affiliate_enabled: boolean;
|
||||
}
|
||||
|
||||
export interface UpdateSettingsRequest {
|
||||
@@ -489,6 +496,10 @@ export interface UpdateSettingsRequest {
|
||||
invitation_code_enabled?: boolean;
|
||||
totp_enabled?: boolean; // TOTP 双因素认证
|
||||
default_balance?: number;
|
||||
affiliate_rebate_rate?: number;
|
||||
affiliate_rebate_freeze_hours?: number;
|
||||
affiliate_rebate_duration_days?: number;
|
||||
affiliate_rebate_per_invitee_cap?: number;
|
||||
default_concurrency?: number;
|
||||
default_user_rpm_limit?: number;
|
||||
default_subscriptions?: DefaultSubscriptionSetting[];
|
||||
@@ -634,6 +645,9 @@ export interface UpdateSettingsRequest {
|
||||
|
||||
// Available Channels feature switch
|
||||
available_channels_enabled?: boolean;
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
affiliate_enabled?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user