mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
Compare commits
96 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c129825f9b | ||
|
|
ff50b8b6ea | ||
|
|
4cbf518f0a | ||
|
|
dc09b367dc | ||
|
|
11fe29223d | ||
|
|
0b84d12dbb | ||
|
|
76e2503d5e | ||
|
|
3ab40269b4 | ||
|
|
650ddb2e39 | ||
|
|
0a914e034c | ||
|
|
6a41cf6a51 | ||
|
|
23555be380 | ||
|
|
47fb38bca1 | ||
|
|
72d5ee4cd1 | ||
|
|
b2bdba78dd | ||
|
|
3930bebaf9 | ||
|
|
e736de1ed9 | ||
|
|
57099a6af6 | ||
|
|
adf01ac880 | ||
|
|
4d145300c3 | ||
|
|
4e4cc80971 | ||
|
|
48912014a1 | ||
|
|
9d801595c9 | ||
|
|
9c448f89a8 | ||
|
|
73b872998e | ||
|
|
094e1171ef | ||
|
|
733627cf9d | ||
|
|
f084d30d65 | ||
|
|
3953dc9ce4 | ||
|
|
8ad099baa6 | ||
|
|
8bf2a7b88a | ||
|
|
40feb86ba4 | ||
|
|
f972a2faf2 | ||
|
|
55a7fa1e07 | ||
|
|
5e54d492be | ||
|
|
8d6d31545f | ||
|
|
17ced6b73a | ||
|
|
7f8f3fe0dd | ||
|
|
46f06b2498 | ||
|
|
7ce5b83215 | ||
|
|
27cad10d30 | ||
|
|
ff6fa0203d | ||
|
|
f7c13af11f | ||
|
|
28dc34b6a3 | ||
|
|
4d676dddd1 | ||
|
|
93d91e20b9 | ||
|
|
63ef23108c | ||
|
|
d78478e866 | ||
|
|
bf43fb4e38 | ||
|
|
a16c66500f | ||
|
|
4b6954f9f0 | ||
|
|
da4b078df2 | ||
|
|
7452fad820 | ||
|
|
4c474616b9 | ||
|
|
6327573534 | ||
|
|
04b2866f65 | ||
|
|
b0a2252ed1 | ||
|
|
30f55a1f72 | ||
|
|
3d4ca5e8d1 | ||
|
|
0537a490f0 | ||
|
|
ca5d029e7c | ||
|
|
1eca03432a | ||
|
|
53b24bc2d8 | ||
|
|
a161f9d045 | ||
|
|
c5a1a82223 | ||
|
|
2ab6b34fd1 | ||
|
|
764afbe37a | ||
|
|
25c7b0d9f4 | ||
|
|
f422ac6dcc | ||
|
|
54de4e008c | ||
|
|
65c27d2c69 | ||
|
|
53f919f8f0 | ||
|
|
c92b88e34a | ||
|
|
ed0c85a17e | ||
|
|
9fe02bba7e | ||
|
|
615557ec20 | ||
|
|
3f05ef2ae3 | ||
|
|
3022090365 | ||
|
|
798fd673e9 | ||
|
|
c056db740d | ||
|
|
a0b5e5bfa0 | ||
|
|
41d0657330 | ||
|
|
1a0cabbfd6 | ||
|
|
9b6dcc57bd | ||
|
|
6d11f9ed77 | ||
|
|
489a4d934e | ||
|
|
b17704d6ef | ||
|
|
496469ac4e | ||
|
|
c1b52615be | ||
|
|
3af9940b85 | ||
|
|
22b1277572 | ||
|
|
aff98d5ae1 | ||
|
|
4e1bb2b445 | ||
|
|
dac6e52091 | ||
|
|
8987e0ba67 | ||
|
|
9d1751ec57 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
docs/claude-relay-service/
|
||||
.codex
|
||||
|
||||
# ===================
|
||||
# Go 后端
|
||||
@@ -121,7 +122,7 @@ scripts
|
||||
.code-review-state
|
||||
#openspec/
|
||||
code-reviews/
|
||||
#AGENTS.md
|
||||
AGENTS.md
|
||||
backend/cmd/server/server
|
||||
deploy/docker-compose.override.yml
|
||||
.gocache/
|
||||
|
||||
@@ -101,6 +101,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
<td>Thanks to Bestproxy for sponsoring this project! <a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
|
||||
<td>Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line.
|
||||
Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
|
||||
Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
@@ -100,6 +100,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
<td>感谢 Bestproxy 赞助了本项目!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
|
||||
<td>感谢 PatewayAI 赞助了本项目!PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型,100% 官方源直供,不掺假不注水,欢迎检验。计费透明,Token 级账单可逐笔核验。
|
||||
同时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。
|
||||
现在通过 <a href="https://pateway.ai/?ch=1tsfr51">此链接</a> 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
@@ -100,6 +100,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
<td>Bestproxy のご支援に感謝します!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
|
||||
<td>PatewayAI のご支援に感謝します!PatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給 — 偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。
|
||||
エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。
|
||||
<a href="https://pateway.ai/?ch=1tsfr51">こちらのリンク</a>から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## エコシステム
|
||||
|
||||
BIN
assets/partners/logos/pateway.png
Normal file
BIN
assets/partners/logos/pateway.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.0 KiB |
@@ -1 +1 @@
|
||||
0.1.117
|
||||
0.1.121
|
||||
|
||||
@@ -65,12 +65,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
||||
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
|
||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
|
||||
apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
affiliateRepository := repository.NewAffiliateRepository(client, db)
|
||||
affiliateService := service.NewAffiliateService(affiliateRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCacheService)
|
||||
affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
||||
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
@@ -145,13 +145,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
@@ -178,7 +179,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
@@ -186,7 +186,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -231,7 +231,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
|
||||
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
|
||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
|
||||
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
|
||||
@@ -26,11 +26,12 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
|
||||
@@ -98,7 +98,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -117,7 +117,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -134,19 +134,29 @@ type UpdateAccountRequest struct {
|
||||
|
||||
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
|
||||
type BulkUpdateAccountsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
Filters *BulkUpdateAccountFilters `json:"filters"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
type BulkUpdateAccountFilters struct {
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Group string `json:"group"`
|
||||
Search string `json:"search"`
|
||||
PrivacyMode string `json:"privacy_mode"`
|
||||
}
|
||||
|
||||
// CheckMixedChannelRequest represents check mixed channel risk request
|
||||
@@ -518,6 +528,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
// 捕获闭包内创建的账号引用,用于创建成功后触发异步探测。
|
||||
// 幂等重放时闭包不会执行 → createdAccount 为 nil → 不重复调度。
|
||||
var createdAccount *service.Account
|
||||
|
||||
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
@@ -539,6 +553,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
createdAccount = account
|
||||
// Antigravity OAuth: 新账号直接设置隐私
|
||||
h.adminService.ForceAntigravityPrivacy(ctx, account)
|
||||
// OpenAI OAuth: 新账号直接设置隐私
|
||||
@@ -567,6 +582,9 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
// OpenAI APIKey 账号创建后异步探测上游 /v1/responses 能力。
|
||||
// 探测失败不影响账号创建响应。
|
||||
h.scheduleOpenAIResponsesProbe(createdAccount)
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
|
||||
@@ -627,9 +645,39 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAI APIKey: credentials 修改后重新探测上游能力(base_url/api_key 可能变更)。
|
||||
// 异步执行,探测失败不影响账号更新响应。
|
||||
if len(req.Credentials) > 0 {
|
||||
h.scheduleOpenAIResponsesProbe(account)
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// scheduleOpenAIResponsesProbe 异步触发 OpenAI APIKey 账号的 Responses API 能力探测。
|
||||
//
|
||||
// 仅对 platform=openai && type=apikey 账号生效;其他账号无操作。
|
||||
// 探测本身在 goroutine 中执行(会发一次 HTTP 请求到上游),不会阻塞
|
||||
// 当前请求。探测错误仅记录日志,不向上下文传播:探测失败时标记保持缺失,
|
||||
// 网关会按"现状即证据"默认走 Responses。
|
||||
func (h *AccountHandler) scheduleOpenAIResponsesProbe(account *service.Account) {
|
||||
if account == nil || account.Platform != service.PlatformOpenAI || account.Type != service.AccountTypeAPIKey {
|
||||
return
|
||||
}
|
||||
if h.accountTestService == nil {
|
||||
return
|
||||
}
|
||||
accountID := account.ID
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("openai_responses_probe_panic", "account_id", accountID, "recover", r)
|
||||
}
|
||||
}()
|
||||
h.accountTestService.ProbeOpenAIAPIKeyResponsesSupport(context.Background(), accountID)
|
||||
}()
|
||||
}
|
||||
|
||||
// Delete handles deleting an account
|
||||
// DELETE /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
@@ -1221,6 +1269,8 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
openaiPrivacyAccounts = append(openaiPrivacyAccounts, account)
|
||||
}
|
||||
}
|
||||
// OpenAI APIKey 账号异步探测 /v1/responses 能力。
|
||||
h.scheduleOpenAIResponsesProbe(account)
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
@@ -1369,6 +1419,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 && req.Filters == nil {
|
||||
response.BadRequest(c, "account_ids or filters is required")
|
||||
return
|
||||
}
|
||||
// base_rpm 输入校验:负值归零,超过 10000 截断
|
||||
sanitizeExtraBaseRPM(req.Extra)
|
||||
|
||||
@@ -1394,6 +1448,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
|
||||
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
|
||||
AccountIDs: req.AccountIDs,
|
||||
Filters: toServiceBulkUpdateAccountFilters(req.Filters),
|
||||
Name: req.Name,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
@@ -1429,6 +1484,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters {
|
||||
if filters == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.BulkUpdateAccountFilters{
|
||||
Platform: filters.Platform,
|
||||
Type: filters.Type,
|
||||
Status: filters.Status,
|
||||
Group: filters.Group,
|
||||
Search: filters.Search,
|
||||
PrivacyMode: filters.PrivacyMode,
|
||||
}
|
||||
}
|
||||
|
||||
// ========== OAuth Handlers ==========
|
||||
|
||||
// GenerateAuthURLRequest represents the request for generating auth URL
|
||||
|
||||
@@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
|
||||
require.Equal(t, float64(2), data["success"])
|
||||
require.Equal(t, float64(0), data["failed"])
|
||||
}
|
||||
|
||||
func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"filters": map[string]any{
|
||||
"platform": "openai",
|
||||
"type": "oauth",
|
||||
"status": "active",
|
||||
"group": "12",
|
||||
"privacy_mode": "blocked",
|
||||
"search": "bulk-target",
|
||||
},
|
||||
"schedulable": true,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
|
||||
require.True(t, isAddrInTrustedProxies(addr, prefixes))
|
||||
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
|
||||
}
|
||||
|
||||
// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin
|
||||
// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为
|
||||
// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。
|
||||
func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
|
||||
t.Run("nil input returns nil", func(t *testing.T) {
|
||||
require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
|
||||
})
|
||||
|
||||
t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: "",
|
||||
Action: "filter",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.NotNil(t, out)
|
||||
require.Len(t, out.Rules, 1)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
|
||||
require.Equal(t, "all", out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: " ",
|
||||
Action: "pass",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: "PRIORITY",
|
||||
Action: "filter",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{
|
||||
{ServiceTier: "priority", Action: "filter", Scope: "all"},
|
||||
{ServiceTier: "flex", Action: "block", Scope: "oauth"},
|
||||
{ServiceTier: "all", Action: "pass", Scope: "apikey"},
|
||||
},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Len(t, out.Rules, 3)
|
||||
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
|
||||
require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -565,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) {
|
||||
for i := range s.apiKeys {
|
||||
if s.apiKeys[i].ID == keyID {
|
||||
s.apiKeys[i].Usage5h = 0
|
||||
s.apiKeys[i].Usage1d = 0
|
||||
s.apiKeys[i].Usage7d = 0
|
||||
s.apiKeys[i].Window5hStart = nil
|
||||
s.apiKeys[i].Window1dStart = nil
|
||||
s.apiKeys[i].Window7dStart = nil
|
||||
k := s.apiKeys[i]
|
||||
return &k, nil
|
||||
}
|
||||
}
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
291
backend/internal/handler/admin/affiliate_handler.go
Normal file
291
backend/internal/handler/admin/affiliate_handler.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AffiliateHandler handles admin affiliate (邀请返利) management:
|
||||
// listing users with custom settings, updating per-user invite codes
|
||||
// and exclusive rebate rates, and batch operations.
|
||||
type AffiliateHandler struct {
|
||||
affiliateService *service.AffiliateService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewAffiliateHandler creates a new admin affiliate handler.
|
||||
func NewAffiliateHandler(affiliateService *service.AffiliateService, adminService service.AdminService) *AffiliateHandler {
|
||||
return &AffiliateHandler{
|
||||
affiliateService: affiliateService,
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// ListUsers returns paginated users with custom affiliate settings.
|
||||
// GET /api/v1/admin/affiliates/users
|
||||
func (h *AffiliateHandler) ListUsers(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
search := c.Query("search")
|
||||
|
||||
entries, total, err := h.affiliateService.AdminListCustomUsers(c.Request.Context(), service.AffiliateAdminFilter{
|
||||
Search: search,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, entries, total, page, pageSize)
|
||||
}
|
||||
|
||||
// UpdateUserSettings updates a user's affiliate settings.
|
||||
// PUT /api/v1/admin/affiliates/users/:user_id
|
||||
//
|
||||
// Both fields are optional and applied independently.
|
||||
type UpdateAffiliateUserRequest struct {
|
||||
AffCode *string `json:"aff_code"`
|
||||
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
|
||||
// ClearRebateRate explicitly clears the per-user rate (sets it to NULL).
|
||||
// Used to disambiguate from "field not provided".
|
||||
ClearRebateRate bool `json:"clear_rebate_rate"`
|
||||
}
|
||||
|
||||
func (h *AffiliateHandler) UpdateUserSettings(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAffiliateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.AffCode != nil {
|
||||
if err := h.affiliateService.AdminUpdateUserAffCode(c.Request.Context(), userID, *req.AffCode); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.ClearRebateRate {
|
||||
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
} else if req.AffRebateRatePercent != nil {
|
||||
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, req.AffRebateRatePercent); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"user_id": userID})
|
||||
}
|
||||
|
||||
// ClearUserSettings removes ALL of a user's custom affiliate settings — clears
|
||||
// the exclusive rebate rate AND regenerates the invite code as a new system
|
||||
// random one. Conceptually this "removes the user from the custom list".
|
||||
//
|
||||
// Both writes happen in this handler; failure of one leaves the other applied,
|
||||
// but the operation is idempotent so the admin can re-run it safely.
|
||||
// DELETE /api/v1/admin/affiliates/users/:user_id
|
||||
func (h *AffiliateHandler) ClearUserSettings(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if _, err := h.affiliateService.AdminResetUserAffCode(c.Request.Context(), userID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"user_id": userID})
|
||||
}
|
||||
|
||||
// BatchSetRate applies the same rebate rate (or clears it) to multiple users.
|
||||
//
|
||||
// Protocol: pass `clear: true` to clear rates (aff_rebate_rate_percent is
|
||||
// ignored). Otherwise aff_rebate_rate_percent is required and applied to
|
||||
// every user_id. The explicit `clear` flag exists because Go's JSON unmarshal
|
||||
// can't distinguish a missing field from `null`, and a silent clear from a
|
||||
// frontend that forgot to include the rate would be a footgun.
|
||||
//
|
||||
// POST /api/v1/admin/affiliates/users/batch-rate
|
||||
type BatchSetRateRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
|
||||
Clear bool `json:"clear"`
|
||||
}
|
||||
|
||||
func (h *AffiliateHandler) BatchSetRate(c *gin.Context) {
|
||||
var req BatchSetRateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.UserIDs) == 0 {
|
||||
response.BadRequest(c, "user_ids cannot be empty")
|
||||
return
|
||||
}
|
||||
if !req.Clear && req.AffRebateRatePercent == nil {
|
||||
response.BadRequest(c, "aff_rebate_rate_percent is required unless clear=true")
|
||||
return
|
||||
}
|
||||
rate := req.AffRebateRatePercent
|
||||
if req.Clear {
|
||||
rate = nil
|
||||
}
|
||||
if err := h.affiliateService.AdminBatchSetUserRebateRate(c.Request.Context(), req.UserIDs, rate); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"affected": len(req.UserIDs)})
|
||||
}
|
||||
|
||||
// AffiliateUserSummary is the minimal user shape returned by LookupUsers,
|
||||
// shared with the frontend's add-custom-user picker.
|
||||
type AffiliateUserSummary struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// LookupUsers searches users by email/username for the "add custom user" modal.
|
||||
// GET /api/v1/admin/affiliates/users/lookup?q=
|
||||
func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
|
||||
keyword := c.Query("q")
|
||||
if keyword == "" {
|
||||
response.Success(c, []AffiliateUserSummary{})
|
||||
return
|
||||
}
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 20, service.UserListFilters{Search: keyword}, "email", "asc")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
result := make([]AffiliateUserSummary, len(users))
|
||||
for i, u := range users {
|
||||
result[i] = AffiliateUserSummary{ID: u.ID, Email: u.Email, Username: u.Username}
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetUserOverview returns one user's affiliate overview.
|
||||
// GET /api/v1/admin/affiliates/users/:user_id/overview
|
||||
func (h *AffiliateHandler) GetUserOverview(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, overview)
|
||||
}
|
||||
|
||||
// ListInviteRecords returns all inviter-invitee relationships.
|
||||
// GET /api/v1/admin/affiliates/invites
|
||||
func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
// ListRebateRecords returns all order-level affiliate rebate records.
|
||||
// GET /api/v1/admin/affiliates/rebates
|
||||
func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
// ListTransferRecords returns all affiliate quota-to-balance transfer records.
|
||||
// GET /api/v1/admin/affiliates/transfers
|
||||
func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter {
|
||||
filter := service.AffiliateRecordFilter{
|
||||
Search: c.Query("search"),
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
SortBy: c.Query("sort_by"),
|
||||
SortDesc: c.Query("sort_order") != "asc",
|
||||
}
|
||||
if filter.PageSize > 100 {
|
||||
filter.PageSize = 100
|
||||
}
|
||||
userTZ := c.Query("timezone")
|
||||
if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil {
|
||||
filter.StartAt = t
|
||||
}
|
||||
if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil {
|
||||
filter.EndAt = t
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
|
||||
end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond)
|
||||
return &end
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -22,12 +22,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle
|
||||
}
|
||||
}
|
||||
|
||||
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
|
||||
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key.
|
||||
type AdminUpdateAPIKeyGroupRequest struct {
|
||||
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
|
||||
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
|
||||
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true=重置 5h/1d/7d 限速用量
|
||||
}
|
||||
|
||||
// UpdateGroup handles updating an API key's group binding
|
||||
// UpdateGroup handles updating an API key's admin-managed fields.
|
||||
// PUT /api/v1/admin/api-keys/:id
|
||||
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
@@ -42,11 +43,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var resetKey *service.APIKey
|
||||
if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage {
|
||||
resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if resetKey != nil && req.GroupID == nil {
|
||||
result.APIKey = resetKey
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
APIKey *dto.APIKey `json:"api_key"`
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
|
||||
require.Nil(t, resp.Data.APIKey.GroupID)
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) {
|
||||
svc := newStubAdminService()
|
||||
now := time.Now()
|
||||
svc.apiKeys[0].Usage5h = 1.2
|
||||
svc.apiKeys[0].Usage1d = 3.4
|
||||
svc.apiKeys[0].Usage7d = 5.6
|
||||
svc.apiKeys[0].Window5hStart = &now
|
||||
svc.apiKeys[0].Window1dStart = &now
|
||||
svc.apiKeys[0].Window7dStart = &now
|
||||
router := setupAPIKeyHandler(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data struct {
|
||||
APIKey struct {
|
||||
Usage5h float64 `json:"usage_5h"`
|
||||
Usage1d float64 `json:"usage_1d"`
|
||||
Usage7d float64 `json:"usage_7d"`
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
} `json:"api_key"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Zero(t, resp.Data.APIKey.Usage5h)
|
||||
require.Zero(t, resp.Data.APIKey.Usage1d)
|
||||
require.Zero(t, resp.Data.APIKey.Usage7d)
|
||||
require.Nil(t, resp.Data.APIKey.Window5hStart)
|
||||
require.Nil(t, resp.Data.APIKey.Window1dStart)
|
||||
require.Nil(t, resp.Data.APIKey.Window7dStart)
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
|
||||
svc := &failingUpdateGroupService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
|
||||
@@ -186,6 +186,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
AffiliateRebateRate: settings.AffiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -206,6 +209,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: settings.EnableCCHSigning,
|
||||
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
||||
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
||||
@@ -242,10 +246,54 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
}
|
||||
|
||||
// OpenAI fast policy (stored under a dedicated setting key)
|
||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||
} else if fastPolicy != nil {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
|
||||
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
|
||||
}
|
||||
|
||||
// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
|
||||
func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
|
||||
for i, r := range s.Rules {
|
||||
rules[i] = dto.OpenAIFastPolicyRule(r)
|
||||
}
|
||||
return &dto.OpenAIFastPolicySettings{Rules: rules}
|
||||
}
|
||||
|
||||
// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
|
||||
//
|
||||
// 规范化 ServiceTier:在 DTO 进入 service 层之前统一把空字符串归一为
|
||||
// service.OpenAIFastTierAny ("all"),避免管理员保存时空串与 "all" 同时
|
||||
// 表达"匹配任意 tier"造成数据库取值的二义性。其它非空值原样透传,由
|
||||
// service.SetOpenAIFastPolicySettings 负责合法值校验。
|
||||
func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
|
||||
for i, r := range s.Rules {
|
||||
rules[i] = service.OpenAIFastPolicyRule(r)
|
||||
tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
|
||||
if tier == "" {
|
||||
tier = service.OpenAIFastTierAny
|
||||
}
|
||||
rules[i].ServiceTier = tier
|
||||
}
|
||||
return &service.OpenAIFastPolicySettings{Rules: rules}
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
@@ -340,6 +388,9 @@ type UpdateSettingsRequest struct {
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
@@ -391,9 +442,10 @@ type UpdateSettingsRequest struct {
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||
|
||||
// Payment visible method routing
|
||||
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
||||
@@ -441,6 +493,12 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// Available Channels feature switch (user-facing)
|
||||
AvailableChannelsEnabled *bool `json:"available_channels_enabled"`
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled *bool `json:"affiliate_enabled"`
|
||||
|
||||
// OpenAI fast/flex policy (optional, only updated when provided)
|
||||
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -480,6 +538,33 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if affiliateRebateRate > service.AffiliateRebateRateMax {
|
||||
affiliateRebateRate = service.AffiliateRebateRateMax
|
||||
}
|
||||
affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
|
||||
if req.AffiliateRebateFreezeHours != nil {
|
||||
affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
|
||||
}
|
||||
if affiliateRebateFreezeHours < 0 {
|
||||
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
|
||||
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
|
||||
if req.AffiliateRebateDurationDays != nil {
|
||||
affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
|
||||
}
|
||||
if affiliateRebateDurationDays < 0 {
|
||||
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
|
||||
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
|
||||
}
|
||||
affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
|
||||
if req.AffiliateRebatePerInviteeCap != nil {
|
||||
affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
|
||||
}
|
||||
if affiliateRebatePerInviteeCap < 0 {
|
||||
affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
|
||||
if req.TableDefaultPageSize <= 0 {
|
||||
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
|
||||
@@ -1132,6 +1217,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
AffiliateRebateRate: affiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: affiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
@@ -1187,6 +1275,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.EnableCCHSigning
|
||||
}(),
|
||||
EnableAnthropicCacheTTL1hInjection: func() bool {
|
||||
if req.EnableAnthropicCacheTTL1hInjection != nil {
|
||||
return *req.EnableAnthropicCacheTTL1hInjection
|
||||
}
|
||||
return previousSettings.EnableAnthropicCacheTTL1hInjection
|
||||
}(),
|
||||
PaymentVisibleMethodAlipaySource: func() string {
|
||||
if req.PaymentVisibleMethodAlipaySource != nil {
|
||||
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
||||
@@ -1265,6 +1359,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.AvailableChannelsEnabled
|
||||
}(),
|
||||
AffiliateEnabled: func() bool {
|
||||
if req.AffiliateEnabled != nil {
|
||||
return *req.AffiliateEnabled
|
||||
}
|
||||
return previousSettings.AffiliateEnabled
|
||||
}(),
|
||||
}
|
||||
|
||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||
@@ -1303,6 +1403,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Update OpenAI fast policy (stored under dedicated key, only when provided).
|
||||
if req.OpenAIFastPolicySettings != nil {
|
||||
if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update payment configuration (integrated into system settings).
|
||||
// Skip if no payment fields were provided (prevents accidental wipe).
|
||||
if h.paymentConfigService != nil && hasPaymentFields(req) {
|
||||
@@ -1447,6 +1555,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -1467,6 +1578,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: updatedSettings.EnableCCHSigning,
|
||||
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
|
||||
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
||||
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
||||
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
||||
@@ -1502,6 +1614,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
|
||||
|
||||
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: updatedSettings.AffiliateEnabled,
|
||||
}
|
||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||
} else if fastPolicy != nil {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
|
||||
}
|
||||
@@ -1755,6 +1874,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AffiliateRebateRate != after.AffiliateRebateRate {
|
||||
changed = append(changed, "affiliate_rebate_rate")
|
||||
}
|
||||
if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
|
||||
changed = append(changed, "affiliate_rebate_freeze_hours")
|
||||
}
|
||||
if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
|
||||
changed = append(changed, "affiliate_rebate_duration_days")
|
||||
}
|
||||
if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
|
||||
changed = append(changed, "affiliate_rebate_per_invitee_cap")
|
||||
}
|
||||
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
|
||||
changed = append(changed, "default_subscriptions")
|
||||
}
|
||||
@@ -1830,6 +1958,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EnableCCHSigning != after.EnableCCHSigning {
|
||||
changed = append(changed, "enable_cch_signing")
|
||||
}
|
||||
if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection {
|
||||
changed = append(changed, "enable_anthropic_cache_ttl_1h_injection")
|
||||
}
|
||||
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
||||
changed = append(changed, "payment_visible_method_alipay_source")
|
||||
}
|
||||
@@ -1870,6 +2001,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
|
||||
changed = append(changed, "available_channels_enabled")
|
||||
}
|
||||
if before.AffiliateEnabled != after.AffiliateEnabled {
|
||||
changed = append(changed, "affiliate_enabled")
|
||||
}
|
||||
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
||||
return changed
|
||||
}
|
||||
|
||||
@@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
panic("unexpected GetValue call")
|
||||
if s.values != nil {
|
||||
if value, ok := s.values[key]; ok {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
|
||||
@@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
// GetBalanceHistory handles getting user's balance/concurrency change history
|
||||
// GET /api/v1/admin/users/:id/balance-history
|
||||
// Query params:
|
||||
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
// - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
|
||||
@@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct {
|
||||
VerifyCode string `json:"verify_code,omitempty"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
InvitationCode string `json:"invitation_code,omitempty"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
user,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
strings.TrimSpace(req.AffCode),
|
||||
); err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rollbackCreatedUser(err) {
|
||||
|
||||
@@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
|
||||
|
||||
type completeOIDCOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService
|
||||
|
||||
type completeWeChatOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -106,11 +106,14 @@ type SystemSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -139,9 +142,10 @@ type SystemSettings struct {
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||
|
||||
// Web Search Emulation
|
||||
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||
@@ -192,6 +196,12 @@ type SystemSettings struct {
|
||||
|
||||
// Available Channels feature switch (user-facing aggregate view)
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -244,6 +254,8 @@ type PublicSettings struct {
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
}
|
||||
|
||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||
@@ -286,6 +298,22 @@ type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// OpenAIFastPolicyRule OpenAI fast/flex 策略规则 DTO
|
||||
type OpenAIFastPolicyRule struct {
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
ModelWhitelist []string `json:"model_whitelist,omitempty"`
|
||||
FallbackAction string `json:"fallback_action,omitempty"`
|
||||
FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIFastPolicySettings OpenAI fast 策略配置 DTO
|
||||
type OpenAIFastPolicySettings struct {
|
||||
Rules []OpenAIFastPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -262,6 +262,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// [DEBUG-STICKY] 打印会话 hash 生成结果
|
||||
reqLog.Info("sticky.session_hash_generated",
|
||||
zap.String("session_hash", sessionHash),
|
||||
zap.String("metadata_user_id_raw", parsedReq.MetadataUserID),
|
||||
)
|
||||
|
||||
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||
platform := ""
|
||||
if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
@@ -278,6 +284,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
// [DEBUG-STICKY] 打印粘性会话查询结果
|
||||
reqLog.Info("sticky.cache_lookup",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("bound_account_id", sessionBoundAccountID),
|
||||
)
|
||||
if sessionBoundAccountID > 0 {
|
||||
prefetchedGroupID := int64(0)
|
||||
if apiKey.GroupID != nil {
|
||||
@@ -286,6 +297,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
} else {
|
||||
reqLog.Info("sticky.no_session_key", zap.String("session_hash", sessionHash))
|
||||
}
|
||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
@@ -536,6 +549,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
reqLog.Info("sticky.selecting_account",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
|
||||
zap.Bool("has_bound_session", hasBoundSession),
|
||||
zap.Int("failed_account_count", len(fs.FailedAccountIDs)),
|
||||
)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
@@ -569,6 +588,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// [DEBUG-STICKY] 打印账号选择结果
|
||||
reqLog.Info("sticky.account_selected",
|
||||
zap.Int64("selected_account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.Bool("slot_acquired", selection.Acquired),
|
||||
zap.Bool("has_wait_plan", selection.WaitPlan != nil),
|
||||
zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
|
||||
zap.Bool("sticky_honored", sessionBoundAccountID > 0 && sessionBoundAccountID == account.ID),
|
||||
)
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||
@@ -635,6 +664,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
reqLog.Info("sticky.bind_after_wait",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("account_id", account.ID),
|
||||
)
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
@@ -829,6 +862,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定粘性会话(成功转发后绑定/刷新)
|
||||
// - 无现有绑定(首次请求):创建绑定
|
||||
// - 选中账号与粘性账号一致:刷新 TTL
|
||||
// - 粘性账号因负载/RPM 被跳过、选中了其他账号:不覆盖原绑定,
|
||||
// 下次请求粘性账号恢复后仍可命中
|
||||
if sessionKey != "" && (sessionBoundAccountID == 0 || sessionBoundAccountID == account.ID) {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
@@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.
|
||||
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ type AdminHandlers struct {
|
||||
ChannelMonitor *admin.ChannelMonitorHandler
|
||||
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
|
||||
Payment *admin.PaymentHandler
|
||||
Affiliate *admin.AffiliateHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -276,7 +277,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UpstreamEndpoint: resolveRawCCUpstreamEndpoint(c, account),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
@@ -299,3 +300,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
|
||||
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
|
||||
// has been probed to not support the Responses API, the request is
|
||||
// forwarded directly to /v1/chat/completions — not through the default
|
||||
// CC→Responses conversion path.
|
||||
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {
|
||||
if account != nil && account.Type == service.AccountTypeAPIKey &&
|
||||
!openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return "/v1/chat/completions"
|
||||
}
|
||||
return GetUpstreamEndpoint(c, account.Platform)
|
||||
}
|
||||
|
||||
@@ -1233,6 +1233,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -651,6 +652,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 test"),
|
||||
})
|
||||
|
||||
require.NotNil(t, got.log.UserAgent)
|
||||
require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent)
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "high", *got.log.ReasoningEffort)
|
||||
require.True(t, got.log.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"),
|
||||
channelMapping: map[string]string{
|
||||
"gpt-5.4-xhigh": "gpt-5.4",
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(),
|
||||
"上游首帧应使用渠道映射后的模型")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "xhigh", *got.log.ReasoningEffort,
|
||||
"usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
|
||||
userAgent: testStringPtr(""),
|
||||
})
|
||||
|
||||
require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "medium", *got.log.ReasoningEffort)
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -796,3 +837,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
return httptest.NewServer(router)
|
||||
}
|
||||
|
||||
type openAIResponsesWSUsageLogCase struct {
|
||||
firstPayload string
|
||||
userAgent *string
|
||||
channelMapping map[string]string
|
||||
}
|
||||
|
||||
type openAIResponsesWSUsageLogResult struct {
|
||||
log *service.UsageLog
|
||||
upstreamFirstPayload []byte
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
if s.account.Platform != platform {
|
||||
return nil, nil
|
||||
}
|
||||
return []service.Account{s.account}, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID != id {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.account
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||
service.UsageLogRepository
|
||||
created chan *service.UsageLog
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if s.created != nil {
|
||||
s.created <- log
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerChannelRepoStub struct {
|
||||
service.ChannelRepository
|
||||
channels []service.Channel
|
||||
groupPlatforms map[int64]string
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
return s.channels, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
out := make(map[int64]string, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" {
|
||||
out[groupID] = platform
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstreamPayloadCh := make(chan []byte, 1)
|
||||
upstreamErrCh := make(chan error, 1)
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
upstreamErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr != nil {
|
||||
upstreamErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
upstreamErrCh <- errors.New("unexpected upstream websocket message type")
|
||||
return
|
||||
}
|
||||
upstreamPayloadCh <- payload
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||
))
|
||||
cancelWrite()
|
||||
if writeErr != nil {
|
||||
upstreamErrCh <- writeErr
|
||||
return
|
||||
}
|
||||
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||
upstreamErrCh <- nil
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
groupID := int64(4201)
|
||||
account := service.Account{
|
||||
ID: 9901,
|
||||
Name: "openai-ws-passthrough-usage-e2e",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": upstreamServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.RunMode = config.RunModeSimple
|
||||
cfg.Default.RateMultiplier = 1
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
|
||||
usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
|
||||
|
||||
var channelSvc *service.ChannelService
|
||||
if len(tc.channelMapping) > 0 {
|
||||
channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
|
||||
channels: []service.Channel{{
|
||||
ID: 7701,
|
||||
Name: "openai-ws-e2e-channel",
|
||||
Status: service.StatusActive,
|
||||
GroupIDs: []int64{groupID},
|
||||
ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
|
||||
}},
|
||||
groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
usageRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
service.NewBillingService(cfg, nil),
|
||||
nil,
|
||||
billingCacheSvc,
|
||||
nil,
|
||||
&service.DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
channelSvc,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: gatewaySvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1801,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1701, Status: service.StatusActive},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
handlerServer := httptest.NewServer(router)
|
||||
defer handlerServer.Close()
|
||||
|
||||
headers := http.Header{}
|
||||
if tc.userAgent != nil {
|
||||
headers.Set("User-Agent", *tc.userAgent)
|
||||
}
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(
|
||||
dialCtx,
|
||||
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||
&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
|
||||
)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, event, err := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
var usageLog *service.UsageLog
|
||||
select {
|
||||
case usageLog = <-usageRepo.created:
|
||||
require.NotNil(t, usageLog)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待 WebSocket usage log 写入超时")
|
||||
}
|
||||
|
||||
var upstreamFirstPayload []byte
|
||||
select {
|
||||
case upstreamFirstPayload = <-upstreamPayloadCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 首帧超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case upstreamErr := <-upstreamErrCh:
|
||||
require.NoError(t, upstreamErr)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 结束超时")
|
||||
}
|
||||
|
||||
return openAIResponsesWSUsageLogResult{
|
||||
log: usageLog,
|
||||
upstreamFirstPayload: upstreamFirstPayload,
|
||||
}
|
||||
}
|
||||
|
||||
func testStringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := ""
|
||||
if parsed.Multipart {
|
||||
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
|
||||
} else {
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
|
||||
@@ -75,5 +75,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ func ProvideAdminHandlers(
|
||||
channelMonitorHandler *admin.ChannelMonitorHandler,
|
||||
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
|
||||
paymentHandler *admin.PaymentHandler,
|
||||
affiliateHandler *admin.AffiliateHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -67,6 +68,7 @@ func ProvideAdminHandlers(
|
||||
ChannelMonitor: channelMonitorHandler,
|
||||
ChannelMonitorTemplate: channelMonitorTemplateHandler,
|
||||
Payment: paymentHandler,
|
||||
Affiliate: affiliateHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,6 +171,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewChannelMonitorHandler,
|
||||
admin.NewChannelMonitorRequestTemplateHandler,
|
||||
admin.NewPaymentHandler,
|
||||
admin.NewAffiliateHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -25,6 +25,7 @@ const (
|
||||
easypayStatusPaid = 1
|
||||
easypayHTTPTimeout = 10 * time.Second
|
||||
maxEasypayResponseSize = 1 << 20 // 1MB
|
||||
maxEasypayErrorSummary = 512
|
||||
tradeStatusSuccess = "TRADE_SUCCESS"
|
||||
signTypeMD5 = "MD5"
|
||||
paymentModePopup = "popup"
|
||||
@@ -42,17 +43,55 @@ type EasyPay struct {
|
||||
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
|
||||
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
|
||||
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
|
||||
if config[k] == "" {
|
||||
if strings.TrimSpace(config[k]) == "" {
|
||||
return nil, fmt.Errorf("easypay config missing required key: %s", k)
|
||||
}
|
||||
}
|
||||
cfg := make(map[string]string, len(config))
|
||||
for k, v := range config {
|
||||
cfg[k] = v
|
||||
}
|
||||
cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
|
||||
return &EasyPay{
|
||||
instanceID: instanceID,
|
||||
config: config,
|
||||
config: cfg,
|
||||
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeEasyPayAPIBase(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return ""
|
||||
}
|
||||
if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
parsed.RawPath = ""
|
||||
parsed.Path = trimEasyPayEndpointPath(parsed.Path)
|
||||
return strings.TrimRight(parsed.String(), "/")
|
||||
}
|
||||
return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
|
||||
}
|
||||
|
||||
func trimEasyPayEndpointPath(path string) string {
|
||||
path = strings.TrimRight(strings.TrimSpace(path), "/")
|
||||
lower := strings.ToLower(path)
|
||||
for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
|
||||
if strings.HasSuffix(lower, endpoint) {
|
||||
return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (e *EasyPay) apiBase() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeEasyPayAPIBase(e.config["apiBase"])
|
||||
}
|
||||
|
||||
func (e *EasyPay) Name() string { return "EasyPay" }
|
||||
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
|
||||
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
|
||||
@@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
|
||||
for k, v := range params {
|
||||
q.Set(k, v)
|
||||
}
|
||||
base := strings.TrimRight(e.config["apiBase"], "/")
|
||||
payURL := base + "/submit.php?" + q.Encode()
|
||||
payURL := e.apiBase() + "/submit.php?" + q.Encode()
|
||||
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
|
||||
}
|
||||
|
||||
@@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
|
||||
params["sign"] = easyPaySign(params, e.config["pkey"])
|
||||
params["sign_type"] = signTypeMD5
|
||||
|
||||
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
|
||||
body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay create: %w", err)
|
||||
}
|
||||
@@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
|
||||
"act": "order", "pid": e.config["pid"],
|
||||
"key": e.config["pkey"], "out_trade_no": tradeNo,
|
||||
}
|
||||
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
|
||||
body, err := e.post(ctx, e.apiBase()+"/api.php", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay query: %w", err)
|
||||
}
|
||||
@@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
|
||||
}
|
||||
|
||||
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||
params := map[string]string{
|
||||
"pid": e.config["pid"], "key": e.config["pkey"],
|
||||
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
|
||||
attempts := e.refundAttempts(req)
|
||||
if len(attempts) == 0 {
|
||||
return nil, fmt.Errorf("easypay refund missing order identifier")
|
||||
}
|
||||
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay refund: %w", err)
|
||||
var firstErr error
|
||||
for i, attempt := range attempts {
|
||||
body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay refund request: %w", err)
|
||||
}
|
||||
if err := parseEasyPayRefundResponse(status, body); err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
type easyPayRefundAttempt struct {
|
||||
params map[string]string
|
||||
refundID string
|
||||
}
|
||||
|
||||
func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
|
||||
base := map[string]string{
|
||||
"pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
|
||||
}
|
||||
var attempts []easyPayRefundAttempt
|
||||
if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
|
||||
params := cloneStringMap(base)
|
||||
params["out_trade_no"] = orderID
|
||||
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
|
||||
}
|
||||
if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
|
||||
params := cloneStringMap(base)
|
||||
params["trade_no"] = tradeNo
|
||||
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
|
||||
}
|
||||
return attempts
|
||||
}
|
||||
|
||||
func cloneStringMap(in map[string]string) map[string]string {
|
||||
out := make(map[string]string, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isEasyPayRefundOrderNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
lower := strings.ToLower(msg)
|
||||
return strings.Contains(msg, "订单编号不存在") ||
|
||||
strings.Contains(msg, "订单不存在") ||
|
||||
strings.Contains(lower, "order not found") ||
|
||||
strings.Contains(lower, "not exist")
|
||||
}
|
||||
|
||||
func parseEasyPayRefundResponse(status int, body []byte) error {
|
||||
summary := summarizeEasyPayResponse(body)
|
||||
if status < http.StatusOK || status >= http.StatusMultipleChoices {
|
||||
return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(string(body))
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
|
||||
lower := strings.ToLower(trimmed)
|
||||
if strings.HasPrefix(lower, "<!doctype html") || strings.HasPrefix(lower, "<html") ||
|
||||
(strings.HasPrefix(lower, "<") && strings.Contains(lower, "html")) {
|
||||
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Code any `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("easypay parse refund: %w", err)
|
||||
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
if resp.Code != easypayCodeSuccess {
|
||||
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg)
|
||||
if !easyPayResponseCodeIsSuccess(resp.Code) {
|
||||
msg := strings.TrimSpace(resp.Msg)
|
||||
if msg == "" {
|
||||
msg = summary
|
||||
}
|
||||
return fmt.Errorf("easypay refund failed (HTTP %d): %s", status, msg)
|
||||
}
|
||||
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func easyPayResponseCodeIsSuccess(code any) bool {
|
||||
switch v := code.(type) {
|
||||
case float64:
|
||||
return int(v) == easypayCodeSuccess
|
||||
case string:
|
||||
n, err := strconv.Atoi(strings.TrimSpace(v))
|
||||
return err == nil && n == easypayCodeSuccess
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeEasyPayResponse(body []byte) string {
|
||||
summary := strings.Join(strings.Fields(string(body)), " ")
|
||||
if summary == "" {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(summary) > maxEasypayErrorSummary {
|
||||
return summary[:maxEasypayErrorSummary] + "..."
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func (e *EasyPay) resolveCID(paymentType string) string {
|
||||
@@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
|
||||
}
|
||||
|
||||
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
|
||||
body, _, err := e.postRaw(ctx, endpoint, params)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
|
||||
form := url.Values{}
|
||||
for k, v := range params {
|
||||
form.Set(k, v)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
resp, err := e.httpClient.Do(req)
|
||||
client := e.httpClient
|
||||
if client == nil {
|
||||
client = &http.Client{Timeout: easypayHTTPTimeout}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
return body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func easyPaySign(params map[string]string, pkey string) string {
|
||||
|
||||
196
backend/internal/payment/provider/easypay_refund_test.go
Normal file
196
backend/internal/payment/provider/easypay_refund_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
func TestNormalizeEasyPayAPIBase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{input: "https://zpayz.cn", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
|
||||
t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var gotPath string
|
||||
var gotQuery url.Values
|
||||
var gotForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotQuery = r.URL.Query()
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Errorf("ParseForm: %v", err)
|
||||
}
|
||||
gotForm = r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL+"/mapi.php")
|
||||
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
TradeNo: "trade-123",
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Refund returned error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.Status != payment.ProviderStatusSuccess {
|
||||
t.Fatalf("Refund response = %+v, want success", resp)
|
||||
}
|
||||
if gotPath != "/api.php" {
|
||||
t.Fatalf("refund path = %q, want /api.php", gotPath)
|
||||
}
|
||||
if gotQuery.Get("act") != "refund" {
|
||||
t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
|
||||
}
|
||||
for key, want := range map[string]string{
|
||||
"pid": "pid-1",
|
||||
"key": "pkey-1",
|
||||
"out_trade_no": "out-456",
|
||||
"money": "1.50",
|
||||
} {
|
||||
if got := gotForm.Get(key); got != want {
|
||||
t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
|
||||
}
|
||||
}
|
||||
if got := gotForm.Get("trade_no"); got != "" {
|
||||
t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var gotForms []url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api.php" {
|
||||
t.Errorf("refund path = %q, want /api.php", r.URL.Path)
|
||||
}
|
||||
if r.URL.Query().Get("act") != "refund" {
|
||||
t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Errorf("ParseForm: %v", err)
|
||||
}
|
||||
gotForms = append(gotForms, r.PostForm)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if len(gotForms) == 1 {
|
||||
_, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL+"/mapi.php")
|
||||
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
TradeNo: "trade-123",
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Refund returned error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
|
||||
t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
|
||||
}
|
||||
if len(gotForms) != 2 {
|
||||
t.Fatalf("refund attempts = %d, want 2", len(gotForms))
|
||||
}
|
||||
if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
|
||||
t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
|
||||
}
|
||||
if got := gotForms[0].Get("trade_no"); got != "" {
|
||||
t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
|
||||
}
|
||||
if got := gotForms[1].Get("trade_no"); got != "trade-123" {
|
||||
t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
|
||||
}
|
||||
if got := gotForms[1].Get("out_trade_no"); got != "" {
|
||||
t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundResponseErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
want string
|
||||
}{
|
||||
{name: "html response", statusCode: http.StatusOK, body: "<html>bad config</html>", want: "non-JSON response (HTTP 200): <html>bad config</html>"},
|
||||
{name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
|
||||
{name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
|
||||
{name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): <empty>"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, _ = w.Write([]byte(tt.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL)
|
||||
_, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Refund returned nil error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.want) {
|
||||
t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
|
||||
t.Helper()
|
||||
|
||||
provider, err := NewEasyPay("test-instance", map[string]string{
|
||||
"pid": "pid-1",
|
||||
"pkey": "pkey-1",
|
||||
"apiBase": apiBase,
|
||||
"notifyUrl": "https://example.com/notify",
|
||||
"returnUrl": "https://example.com/return",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewEasyPay: %v", err)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
|
||||
assert.Equal(t, 5, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cached",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Cached response"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 54006,
|
||||
OutputTokens: 123,
|
||||
TotalTokens: 54129,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 50688,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
|
||||
assert.Equal(t, 3318, anth.Usage.InputTokens)
|
||||
assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 123, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cached_clamp",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 5,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 150,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
|
||||
assert.Equal(t, 0, anth.Usage.InputTokens)
|
||||
assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 5, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_456",
|
||||
@@ -209,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "tool_use", anth.Content[1].Type)
|
||||
assert.Equal(t, "call_1", anth.Content[1].ID)
|
||||
assert.Equal(t, "get_weather", anth.Content[1].Name)
|
||||
assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_read",
|
||||
Model: "gpt-5.5",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_read",
|
||||
Name: "Read",
|
||||
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
require.Len(t, anth.Content, 1)
|
||||
assert.Equal(t, "tool_use", anth.Content[0].Type)
|
||||
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input))
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_other",
|
||||
Model: "gpt-5.5",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_other",
|
||||
Name: "Search",
|
||||
Arguments: `{"query":""}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
require.Len(t, anth.Content, 1)
|
||||
assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input))
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_Reasoning(t *testing.T) {
|
||||
@@ -343,6 +434,75 @@ func TestStreamingTextOnly(t *testing.T) {
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.done",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
|
||||
assert.Equal(t, 12, events[0].Usage.InputTokens)
|
||||
assert.Equal(t, 4, events[0].Usage.OutputTokens)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||
}
|
||||
|
||||
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.done",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, "max_tokens", events[0].Delta.StopReason)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||
}
|
||||
|
||||
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 54006,
|
||||
OutputTokens: 123,
|
||||
TotalTokens: 54129,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 50688,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, 3318, events[0].Usage.InputTokens)
|
||||
assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 123, events[0].Usage.OutputTokens)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingToolCall(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
@@ -393,6 +553,41 @@ func TestStreamingToolCall(t *testing.T) {
|
||||
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
|
||||
}
|
||||
|
||||
func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 0,
|
||||
Delta: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
|
||||
}, state)
|
||||
assert.Len(t, events, 0)
|
||||
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.done",
|
||||
OutputIndex: 0,
|
||||
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "content_block_delta", events[0].Type)
|
||||
assert.Equal(t, "input_json_delta", events[0].Delta.Type)
|
||||
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON)
|
||||
assert.Equal(t, "content_block_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingReasoning(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
@@ -835,9 +1030,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
fn, ok := tc["function"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", fn["name"])
|
||||
assert.Equal(t, "get_weather", tc["name"])
|
||||
assert.NotContains(t, tc, "function")
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) {
|
||||
req := &ResponsesRequest{
|
||||
Model: "gpt-5.2",
|
||||
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
|
||||
ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`),
|
||||
}
|
||||
|
||||
resp, err := ResponsesToAnthropicRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var tc map[string]string
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "tool", tc["type"])
|
||||
assert.Equal(t, "get_weather", tc["name"])
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) {
|
||||
req := &ResponsesRequest{
|
||||
Model: "gpt-5.2",
|
||||
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
|
||||
ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`),
|
||||
}
|
||||
|
||||
resp, err := ResponsesToAnthropicRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var tc map[string]string
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "tool", tc["type"])
|
||||
assert.Equal(t, "get_weather", tc["name"])
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
// {"type":"auto"} → "auto"
|
||||
// {"type":"any"} → "required"
|
||||
// {"type":"none"} → "none"
|
||||
// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
// {"type":"tool","name":"X"} → {"type":"function","name":"X"}
|
||||
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
|
||||
var tc struct {
|
||||
Type string `json:"type"`
|
||||
@@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
|
||||
return json.Marshal("none")
|
||||
case "tool":
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": tc.Name},
|
||||
"type": "function",
|
||||
"name": tc.Name,
|
||||
})
|
||||
default:
|
||||
// Pass through unknown types as-is
|
||||
|
||||
@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
assert.Equal(t, "get_weather", tc["name"])
|
||||
assert.NotContains(t, tc, "function")
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||
@@ -718,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.done",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 2)
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.done",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 2)
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason)
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
|
||||
//
|
||||
// "auto" → "auto"
|
||||
// "none" → "none"
|
||||
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
// {"name":"X"} → {"type":"function","name":"X"}
|
||||
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||
var s string
|
||||
@@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": obj.Name},
|
||||
"type": "function",
|
||||
"name": obj.Name,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallID(item.CallID),
|
||||
Name: item.Name,
|
||||
Input: json.RawMessage(item.Arguments),
|
||||
Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
|
||||
})
|
||||
case "web_search_call":
|
||||
toolUseID := "srvtoolu_" + item.ID
|
||||
@@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
|
||||
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
|
||||
|
||||
if resp.Usage != nil {
|
||||
out.Usage = AnthropicUsage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil {
|
||||
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
|
||||
if usage == nil {
|
||||
return AnthropicUsage{}
|
||||
}
|
||||
|
||||
cachedTokens := 0
|
||||
if usage.InputTokensDetails != nil {
|
||||
cachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
|
||||
inputTokens := usage.InputTokens - cachedTokens
|
||||
if inputTokens < 0 {
|
||||
inputTokens = 0
|
||||
}
|
||||
|
||||
return AnthropicUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: cachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
@@ -113,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
|
||||
if name != "Read" || raw == "" {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
|
||||
var input map[string]json.RawMessage
|
||||
if err := json.Unmarshal([]byte(raw), &input); err != nil {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
|
||||
if pages, ok := input["pages"]; !ok || string(pages) != `""` {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
|
||||
delete(input, "pages")
|
||||
sanitized, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -126,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
|
||||
ContentBlockIndex int
|
||||
ContentBlockOpen bool
|
||||
CurrentBlockType string // "text" | "thinking" | "tool_use"
|
||||
CurrentToolName string
|
||||
CurrentToolArgs string
|
||||
|
||||
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
|
||||
OutputIndexToBlockIdx map[int]int
|
||||
@@ -165,14 +205,16 @@ func ResponsesEventToAnthropicEvents(
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToAnthHandleFuncArgsDelta(evt, state)
|
||||
case "response.function_call_arguments.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
return resToAnthHandleFuncArgsDone(evt, state)
|
||||
case "response.output_item.done":
|
||||
return resToAnthHandleOutputItemDone(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToAnthHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||
return resToAnthHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
@@ -262,6 +304,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
|
||||
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "tool_use"
|
||||
state.CurrentToolName = evt.Item.Name
|
||||
state.CurrentToolArgs = ""
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
@@ -342,6 +386,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
|
||||
return nil
|
||||
}
|
||||
|
||||
if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
|
||||
state.CurrentToolArgs += evt.Delta
|
||||
return nil
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -357,6 +406,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
|
||||
return resToAnthHandleBlockDone(state)
|
||||
}
|
||||
|
||||
raw := evt.Arguments
|
||||
if raw == "" {
|
||||
raw = state.CurrentToolArgs
|
||||
}
|
||||
sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
|
||||
if len(sanitized) == 0 {
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
events := []AnthropicStreamEvent{{
|
||||
Type: "content_block_delta",
|
||||
Index: &idx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(sanitized),
|
||||
},
|
||||
}}
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
return events
|
||||
}
|
||||
|
||||
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
@@ -466,11 +542,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
|
||||
stopReason := "end_turn"
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
state.InputTokens = evt.Response.Usage.InputTokens
|
||||
state.OutputTokens = evt.Response.Usage.OutputTokens
|
||||
if evt.Response.Usage.InputTokensDetails != nil {
|
||||
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
|
||||
state.InputTokens = usage.InputTokens
|
||||
state.OutputTokens = usage.OutputTokens
|
||||
state.CacheReadInputTokens = usage.CacheReadInputTokens
|
||||
}
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
@@ -509,6 +584,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
|
||||
idx := state.ContentBlockIndex
|
||||
state.ContentBlockOpen = false
|
||||
state.ContentBlockIndex++
|
||||
state.CurrentToolName = ""
|
||||
state.CurrentToolArgs = ""
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx,
|
||||
|
||||
@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
|
||||
// "auto" → {"type":"auto"}
|
||||
// "required" → {"type":"any"}
|
||||
// "none" → {"type":"none"}
|
||||
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
|
||||
// {"type":"function","name":"X"} → {"type":"tool","name":"X"}
|
||||
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
|
||||
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try as string first
|
||||
var s string
|
||||
@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage
|
||||
// Try as object with type=function
|
||||
var tc struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
|
||||
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" {
|
||||
name := strings.TrimSpace(tc.Name)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(tc.Function.Name)
|
||||
}
|
||||
if name == "" {
|
||||
return raw, nil
|
||||
}
|
||||
return json.Marshal(map[string]string{
|
||||
"type": "tool",
|
||||
"name": tc.Function.Name,
|
||||
"name": name,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
|
||||
@@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct {
|
||||
type ResponsesStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// response.created / response.completed / response.failed / response.incomplete
|
||||
// response.created / response.completed / response.done / response.failed / response.incomplete
|
||||
Response *ResponsesResponse `json:"response,omitempty"`
|
||||
|
||||
// response.output_item.added / response.output_item.done
|
||||
|
||||
@@ -2,16 +2,28 @@ package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"compress/zlib"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
const (
|
||||
requestBodyReadInitCap = 512
|
||||
requestBodyReadMaxInitCap = 1 << 20
|
||||
// maxDecompressedBodySize limits the decompressed request body to 64 MB
|
||||
// to prevent decompression bomb attacks.
|
||||
maxDecompressedBodySize = 64 << 20
|
||||
)
|
||||
|
||||
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
|
||||
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
|
||||
// on content length, transparently decoding any Content-Encoding the upstream
|
||||
// client used to compress the body (zstd, gzip, deflate).
|
||||
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
@@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
|
||||
if _, err := io.Copy(buf, req.Body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
raw := buf.Bytes()
|
||||
|
||||
enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding")))
|
||||
if enc == "" || enc == "identity" {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
decoded, err := decompressRequestBody(enc, raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err)
|
||||
}
|
||||
|
||||
req.Header.Del("Content-Encoding")
|
||||
req.Header.Del("Content-Length")
|
||||
req.ContentLength = int64(len(decoded))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func decompressRequestBody(encoding string, raw []byte) ([]byte, error) {
|
||||
switch encoding {
|
||||
case "zstd":
|
||||
dec, err := zstd.NewReader(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer dec.Close()
|
||||
return io.ReadAll(io.LimitReader(dec, maxDecompressedBodySize))
|
||||
case "gzip", "x-gzip":
|
||||
gr, err := gzip.NewReader(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = gr.Close() }()
|
||||
return io.ReadAll(io.LimitReader(gr, maxDecompressedBodySize))
|
||||
case "deflate":
|
||||
zr, err := zlib.NewReader(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = zr.Close() }()
|
||||
return io.ReadAll(io.LimitReader(zr, maxDecompressedBodySize))
|
||||
default:
|
||||
return nil, errors.New("unsupported Content-Encoding")
|
||||
}
|
||||
}
|
||||
|
||||
143
backend/internal/pkg/httputil/body_test.go
Normal file
143
backend/internal/pkg/httputil/body_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"compress/zlib"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}`
|
||||
|
||||
func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request {
|
||||
t.Helper()
|
||||
req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
if encoding != "" {
|
||||
req.Header.Set("Content-Encoding", encoding)
|
||||
}
|
||||
req.ContentLength = int64(len(body))
|
||||
return req
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) {
|
||||
req := newRequestWithBody(t, []byte(samplePayload), "")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) {
|
||||
enc, _ := zstd.NewWriter(nil)
|
||||
compressed := enc.EncodeAll([]byte(samplePayload), nil)
|
||||
_ = enc.Close()
|
||||
|
||||
req := newRequestWithBody(t, compressed, "zstd")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
if req.Header.Get("Content-Encoding") != "" {
|
||||
t.Fatalf("Content-Encoding should be cleared after decoding")
|
||||
}
|
||||
if req.ContentLength != int64(len(samplePayload)) {
|
||||
t.Fatalf("ContentLength not updated: %d", req.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
if _, err := gw.Write([]byte(samplePayload)); err != nil {
|
||||
t.Fatalf("gzip write: %v", err)
|
||||
}
|
||||
if err := gw.Close(); err != nil {
|
||||
t.Fatalf("gzip close: %v", err)
|
||||
}
|
||||
|
||||
req := newRequestWithBody(t, buf.Bytes(), "gzip")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
zw := zlib.NewWriter(&buf)
|
||||
if _, err := zw.Write([]byte(samplePayload)); err != nil {
|
||||
t.Fatalf("zlib write: %v", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatalf("zlib close: %v", err)
|
||||
}
|
||||
|
||||
req := newRequestWithBody(t, buf.Bytes(), "deflate")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) {
|
||||
req := newRequestWithBody(t, []byte(samplePayload), "br")
|
||||
_, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported encoding, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "br") {
|
||||
t.Fatalf("error should mention encoding, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) {
|
||||
req := newRequestWithBody(t, []byte("not actually zstd"), "zstd")
|
||||
_, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for corrupt zstd body, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil body, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) {
|
||||
req := newRequestWithBody(t, []byte(samplePayload), "identity")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
75
backend/internal/pkg/openai_compat/upstream_capability.go
Normal file
75
backend/internal/pkg/openai_compat/upstream_capability.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Package openai_compat 提供 OpenAI 协议族在不同上游间的能力差异判定工具。
|
||||
//
|
||||
// 背景:sub2api 的 OpenAI APIKey 账号通过 base_url 接入多种第三方 OpenAI 兼容上游
|
||||
// (DeepSeek、Kimi、GLM、Qwen 等)。这些上游普遍只支持 /v1/chat/completions,
|
||||
// 不存在 /v1/responses 端点。但网关历史代码无差别走 CC→Responses 转换并打到
|
||||
// /v1/responses,导致兼容上游 404。
|
||||
//
|
||||
// 本包提供基于"账号探测标记"的能力判定,配合
|
||||
// internal/service/openai_apikey_responses_probe.go 在创建/修改账号时一次性
|
||||
// 探测并落标。
|
||||
//
|
||||
// 设计取舍:
|
||||
// - 不维护静态 host 白名单——避免新增厂商时必须改代码(讨论沉淀于
|
||||
// pensieve/short-term/knowledge/upstream-capability-detection-design-tradeoffs)
|
||||
// - 标记缺失时默认 true(即"走 Responses"),保持与重构前老代码完全一致的存量
|
||||
// 账号行为("现状即证据"原则;详见
|
||||
// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems)
|
||||
package openai_compat
|
||||
|
||||
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。
|
||||
//
|
||||
// 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。
|
||||
type AccountResponsesSupport int
|
||||
|
||||
const (
|
||||
// ResponsesSupportUnknown 表示账号尚未完成能力探测(extra 字段缺失)。
|
||||
// 上游路由层应按"现状即证据"原则默认走 Responses,保持与重构前一致。
|
||||
ResponsesSupportUnknown AccountResponsesSupport = iota
|
||||
|
||||
// ResponsesSupportYes 探测确认上游支持 /v1/responses。
|
||||
ResponsesSupportYes
|
||||
|
||||
// ResponsesSupportNo 探测确认上游不支持 /v1/responses,应走
|
||||
// /v1/chat/completions 直转路径。
|
||||
ResponsesSupportNo
|
||||
)
|
||||
|
||||
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。
|
||||
// 值类型为 bool:true=支持、false=不支持、键缺失=未探测。
|
||||
const ExtraKeyResponsesSupported = "openai_responses_supported"
|
||||
|
||||
// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。
|
||||
//
|
||||
// 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按
|
||||
// "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI)。
|
||||
func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
|
||||
if extra == nil {
|
||||
return ResponsesSupportUnknown
|
||||
}
|
||||
v, ok := extra[ExtraKeyResponsesSupported]
|
||||
if !ok {
|
||||
return ResponsesSupportUnknown
|
||||
}
|
||||
supported, ok := v.(bool)
|
||||
if !ok {
|
||||
return ResponsesSupportUnknown
|
||||
}
|
||||
if supported {
|
||||
return ResponsesSupportYes
|
||||
}
|
||||
return ResponsesSupportNo
|
||||
}
|
||||
|
||||
// ShouldUseResponsesAPI 判断 OpenAI APIKey 账号的入站 /v1/chat/completions 请求
|
||||
// 是否应走"CC→Responses 转换 + 上游 /v1/responses"路径。
|
||||
//
|
||||
// 返回 true 的两种情况:
|
||||
// 1. 账号已探测确认支持 Responses
|
||||
// 2. 账号未探测(标记缺失)——按"现状即证据"原则保留旧行为
|
||||
//
|
||||
// 仅当账号已探测且确认不支持时返回 false,此时调用方应走 CC 直转路径
|
||||
// (详见 internal/service/openai_gateway_chat_completions_raw.go)。
|
||||
func ShouldUseResponsesAPI(extra map[string]any) bool {
|
||||
return ResolveResponsesSupport(extra) != ResponsesSupportNo
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package openai_compat
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveResponsesSupport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
want AccountResponsesSupport
|
||||
}{
|
||||
{"nil extra", nil, ResponsesSupportUnknown},
|
||||
{"empty extra", map[string]any{}, ResponsesSupportUnknown},
|
||||
{"key missing", map[string]any{"other": "value"}, ResponsesSupportUnknown},
|
||||
{"value true", map[string]any{ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
|
||||
{"value false", map[string]any{ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
|
||||
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
|
||||
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
|
||||
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ResolveResponsesSupport(tc.extra)
|
||||
if got != tc.want {
|
||||
t.Errorf("ResolveResponsesSupport(%v) = %v, want %v", tc.extra, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldUseResponsesAPI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
want bool
|
||||
}{
|
||||
// 关键不变量:未探测必须返回 true(保留旧行为)
|
||||
{"unknown defaults to true (preserve old behavior)", nil, true},
|
||||
{"unknown empty defaults to true", map[string]any{}, true},
|
||||
{"unknown wrong type defaults to true", map[string]any{ExtraKeyResponsesSupported: "yes"}, true},
|
||||
|
||||
// 已探测:标记决定
|
||||
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
|
||||
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ShouldUseResponsesAPI(tc.extra)
|
||||
if got != tc.want {
|
||||
t.Errorf("ShouldUseResponsesAPI(%v) = %v, want %v", tc.extra, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -22,6 +22,34 @@ const (
|
||||
|
||||
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
|
||||
|
||||
const affiliateUserOverviewSQL = `
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.aff_code,
|
||||
COALESCE(ua.aff_rebate_rate_percent, 0)::double precision,
|
||||
(ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate,
|
||||
ua.aff_count,
|
||||
COALESCE(rebated.rebated_invitee_count, 0),
|
||||
(ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision,
|
||||
ua.aff_history_quota::double precision
|
||||
FROM user_affiliates ua
|
||||
JOIN users u ON u.id = ua.user_id
|
||||
LEFT JOIN (
|
||||
SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count
|
||||
FROM user_affiliate_ledger
|
||||
WHERE action = 'accrue' AND source_user_id IS NOT NULL
|
||||
GROUP BY user_id
|
||||
) rebated ON rebated.user_id = ua.user_id
|
||||
LEFT JOIN (
|
||||
SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota
|
||||
FROM user_affiliate_ledger
|
||||
WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW()
|
||||
GROUP BY user_id
|
||||
) matured ON matured.user_id = ua.user_id
|
||||
WHERE ua.user_id = $1
|
||||
LIMIT 1`
|
||||
|
||||
type affiliateQueryExecer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
@@ -86,17 +114,21 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
|
||||
return bound, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) {
|
||||
if amount <= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var applied bool
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
res, err := txClient.ExecContext(txCtx,
|
||||
"UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
|
||||
amount, inviterID,
|
||||
)
|
||||
// freezeHours > 0: add to frozen quota; == 0: add to available quota directly
|
||||
var updateSQL string
|
||||
if freezeHours > 0 {
|
||||
updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
|
||||
} else {
|
||||
updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
|
||||
}
|
||||
res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -106,10 +138,19 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
if freezeHours > 0 {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`,
|
||||
inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
} else {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
applied = true
|
||||
@@ -121,6 +162,76 @@ VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID);
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx,
|
||||
`SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`,
|
||||
inviterID, inviteeUserID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
var total float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return total, rows.Close()
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
|
||||
var thawed float64
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
var err error
|
||||
thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID)
|
||||
return err
|
||||
})
|
||||
return thawed, err
|
||||
}
|
||||
|
||||
// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx.
|
||||
func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
|
||||
rows, err := txClient.QueryContext(txCtx, `
|
||||
WITH matured AS (
|
||||
UPDATE user_affiliate_ledger
|
||||
SET frozen_until = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
AND frozen_until IS NOT NULL
|
||||
AND frozen_until <= NOW()
|
||||
RETURNING amount
|
||||
)
|
||||
SELECT COALESCE(SUM(amount), 0) FROM matured`, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("thaw frozen quota: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var thawed float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&thawed); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if thawed <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
_, err = txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_quota = aff_quota + $1,
|
||||
aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, thawed, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("move thawed quota: %w", err)
|
||||
}
|
||||
return thawed, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
|
||||
var transferred float64
|
||||
var newBalance float64
|
||||
@@ -130,6 +241,11 @@ func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID
|
||||
return err
|
||||
}
|
||||
|
||||
// Thaw any matured frozen quota before transfer.
|
||||
if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
|
||||
return fmt.Errorf("thaw before transfer: %w", err)
|
||||
}
|
||||
|
||||
rows, err := txClient.QueryContext(txCtx, `
|
||||
WITH claimed AS (
|
||||
SELECT aff_quota::double precision AS amount
|
||||
@@ -187,9 +303,32 @@ FROM cleared`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
|
||||
INSERT INTO user_affiliate_ledger (
|
||||
user_id,
|
||||
action,
|
||||
amount,
|
||||
source_user_id,
|
||||
balance_after,
|
||||
aff_quota_after,
|
||||
aff_frozen_quota_after,
|
||||
aff_history_quota_after,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`,
|
||||
userID,
|
||||
transferred,
|
||||
snapshot.BalanceAfter,
|
||||
snapshot.AvailableQuotaAfter,
|
||||
snapshot.FrozenQuotaAfter,
|
||||
snapshot.HistoryQuotaAfter,
|
||||
); err != nil {
|
||||
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
|
||||
}
|
||||
|
||||
@@ -211,10 +350,16 @@ func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64,
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.created_at
|
||||
ua.created_at,
|
||||
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
|
||||
FROM user_affiliates ua
|
||||
LEFT JOIN users u ON u.id = ua.user_id
|
||||
LEFT JOIN user_affiliate_ledger ual
|
||||
ON ual.user_id = $1
|
||||
AND ual.source_user_id = ua.user_id
|
||||
AND ual.action = 'accrue'
|
||||
WHERE ua.inviter_id = $1
|
||||
GROUP BY ua.user_id, u.email, u.username, ua.created_at
|
||||
ORDER BY ua.created_at DESC
|
||||
LIMIT $2`, inviterID, limit)
|
||||
if err != nil {
|
||||
@@ -226,7 +371,7 @@ LIMIT $2`, inviterID, limit)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateInvitee
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil {
|
||||
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.CreatedAt = &createdAt
|
||||
@@ -238,6 +383,349 @@ LIMIT $2`, inviterID, limit)
|
||||
return invitees, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{
|
||||
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
|
||||
"ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code",
|
||||
})
|
||||
|
||||
total, err := queryAffiliateRecordCount(ctx, client, `
|
||||
SELECT COUNT(*)
|
||||
FROM user_affiliates ua
|
||||
JOIN users invitee ON invitee.id = ua.user_id
|
||||
JOIN users inviter ON inviter.id = ua.inviter_id
|
||||
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
|
||||
`+where, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||
"inviter": "inviter.email",
|
||||
"invitee": "invitee.email",
|
||||
"aff_code": "inviter_aff.aff_code",
|
||||
"total_rebate": "total_rebate",
|
||||
"created_at": "ua.created_at",
|
||||
}, "ua.created_at")
|
||||
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT ua.inviter_id,
|
||||
COALESCE(inviter.email, ''),
|
||||
COALESCE(inviter.username, ''),
|
||||
ua.user_id,
|
||||
COALESCE(invitee.email, ''),
|
||||
COALESCE(invitee.username, ''),
|
||||
COALESCE(inviter_aff.aff_code, ''),
|
||||
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate,
|
||||
ua.created_at
|
||||
FROM user_affiliates ua
|
||||
JOIN users invitee ON invitee.id = ua.user_id
|
||||
JOIN users inviter ON inviter.id = ua.inviter_id
|
||||
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
|
||||
LEFT JOIN user_affiliate_ledger ual
|
||||
ON ual.user_id = ua.inviter_id
|
||||
AND ual.source_user_id = ua.user_id
|
||||
AND ual.action = 'accrue'
|
||||
`+where+`
|
||||
GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at
|
||||
`+orderBy+`
|
||||
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]service.AffiliateInviteRecord, 0)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateInviteRecord
|
||||
if err := rows.Scan(
|
||||
&item.InviterID,
|
||||
&item.InviterEmail,
|
||||
&item.InviterUsername,
|
||||
&item.InviteeID,
|
||||
&item.InviteeEmail,
|
||||
&item.InviteeUsername,
|
||||
&item.AffCode,
|
||||
&item.TotalRebate,
|
||||
&item.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
|
||||
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
|
||||
"po.id::text", "po.out_trade_no", "po.payment_type", "po.status",
|
||||
})
|
||||
baseJoin := `
|
||||
FROM user_affiliate_ledger ual
|
||||
JOIN payment_orders po ON po.id = ual.source_order_id
|
||||
JOIN users invitee ON invitee.id = ual.source_user_id
|
||||
JOIN users inviter ON inviter.id = ual.user_id
|
||||
WHERE ual.action = 'accrue'
|
||||
AND ual.source_order_id IS NOT NULL`
|
||||
if where != "" {
|
||||
where = strings.Replace(where, "WHERE ", " AND ", 1)
|
||||
}
|
||||
|
||||
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||
"order": "po.id",
|
||||
"inviter": "inviter.email",
|
||||
"invitee": "invitee.email",
|
||||
"order_amount": "po.amount",
|
||||
"pay_amount": "po.pay_amount",
|
||||
"rebate_amount": "ual.amount",
|
||||
"payment_type": "po.payment_type",
|
||||
"order_status": "po.status",
|
||||
"created_at": "ual.created_at",
|
||||
}, "ual.created_at")
|
||||
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT po.id,
|
||||
po.out_trade_no,
|
||||
ual.user_id,
|
||||
COALESCE(inviter.email, ''),
|
||||
COALESCE(inviter.username, ''),
|
||||
ual.source_user_id,
|
||||
COALESCE(invitee.email, ''),
|
||||
COALESCE(invitee.username, ''),
|
||||
po.amount::double precision,
|
||||
po.pay_amount::double precision,
|
||||
ual.amount::double precision,
|
||||
po.payment_type,
|
||||
po.status,
|
||||
ual.created_at
|
||||
`+baseJoin+where+`
|
||||
`+orderBy+`
|
||||
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]service.AffiliateRebateRecord, 0)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateRebateRecord
|
||||
if err := rows.Scan(
|
||||
&item.OrderID,
|
||||
&item.OutTradeNo,
|
||||
&item.InviterID,
|
||||
&item.InviterEmail,
|
||||
&item.InviterUsername,
|
||||
&item.InviteeID,
|
||||
&item.InviteeEmail,
|
||||
&item.InviteeUsername,
|
||||
&item.OrderAmount,
|
||||
&item.PayAmount,
|
||||
&item.RebateAmount,
|
||||
&item.PaymentType,
|
||||
&item.OrderStatus,
|
||||
&item.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
|
||||
"u.email", "u.username", "u.id::text",
|
||||
})
|
||||
baseJoin := `
|
||||
FROM user_affiliate_ledger ual
|
||||
JOIN users u ON u.id = ual.user_id
|
||||
WHERE ual.action = 'transfer'`
|
||||
if where != "" {
|
||||
where = strings.Replace(where, "WHERE ", " AND ", 1)
|
||||
}
|
||||
|
||||
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
|
||||
"user": "u.email",
|
||||
"amount": "ual.amount",
|
||||
"balance_after": "ual.balance_after",
|
||||
"available_quota_after": "ual.aff_quota_after",
|
||||
"frozen_quota_after": "ual.aff_frozen_quota_after",
|
||||
"history_quota_after": "ual.aff_history_quota_after",
|
||||
"created_at": "ual.created_at",
|
||||
}, "ual.created_at")
|
||||
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT ual.id,
|
||||
ual.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ual.amount::double precision,
|
||||
ual.balance_after::double precision,
|
||||
ual.aff_quota_after::double precision,
|
||||
ual.aff_frozen_quota_after::double precision,
|
||||
ual.aff_history_quota_after::double precision,
|
||||
ual.created_at
|
||||
`+baseJoin+where+`
|
||||
`+orderBy+`
|
||||
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]service.AffiliateTransferRecord, 0)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateTransferRecord
|
||||
var balanceAfter sql.NullFloat64
|
||||
var availableQuotaAfter sql.NullFloat64
|
||||
var frozenQuotaAfter sql.NullFloat64
|
||||
var historyQuotaAfter sql.NullFloat64
|
||||
if err := rows.Scan(
|
||||
&item.LedgerID,
|
||||
&item.UserID,
|
||||
&item.UserEmail,
|
||||
&item.Username,
|
||||
&item.Amount,
|
||||
&balanceAfter,
|
||||
&availableQuotaAfter,
|
||||
&frozenQuotaAfter,
|
||||
&historyQuotaAfter,
|
||||
&item.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
item.BalanceAfter = nullableFloat64Ptr(balanceAfter)
|
||||
item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter)
|
||||
item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter)
|
||||
item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter)
|
||||
item.SnapshotAvailable = balanceAfter.Valid &&
|
||||
availableQuotaAfter.Valid &&
|
||||
frozenQuotaAfter.Valid &&
|
||||
historyQuotaAfter.Valid
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) {
|
||||
if userID <= 0 {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
|
||||
var overview service.AffiliateUserOverview
|
||||
var customRate float64
|
||||
var hasCustomRate bool
|
||||
if err := rows.Scan(
|
||||
&overview.UserID,
|
||||
&overview.Email,
|
||||
&overview.Username,
|
||||
&overview.AffCode,
|
||||
&customRate,
|
||||
&hasCustomRate,
|
||||
&overview.InvitedCount,
|
||||
&overview.RebatedInviteeCount,
|
||||
&overview.AvailableQuota,
|
||||
&overview.HistoryQuota,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hasCustomRate {
|
||||
overview.RebateRatePercent = customRate
|
||||
overview.RebateRateCustom = true
|
||||
}
|
||||
return &overview, rows.Err()
|
||||
}
|
||||
|
||||
func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) {
|
||||
clauses := make([]string, 0, 3)
|
||||
args := make([]any, 0, 3)
|
||||
if filter.StartAt != nil {
|
||||
args = append(args, *filter.StartAt)
|
||||
clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args)))
|
||||
}
|
||||
if filter.EndAt != nil {
|
||||
args = append(args, *filter.EndAt)
|
||||
clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args)))
|
||||
}
|
||||
search := strings.TrimSpace(filter.Search)
|
||||
if search != "" && len(searchColumns) > 0 {
|
||||
args = append(args, "%"+strings.ToLower(search)+"%")
|
||||
parts := make([]string, 0, len(searchColumns))
|
||||
for _, col := range searchColumns {
|
||||
parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args)))
|
||||
}
|
||||
clauses = append(clauses, "("+strings.Join(parts, " OR ")+")")
|
||||
}
|
||||
if len(clauses) == 0 {
|
||||
return "", args
|
||||
}
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string {
|
||||
column := sortColumns[filter.SortBy]
|
||||
if column == "" {
|
||||
column = fallbackColumn
|
||||
}
|
||||
direction := "DESC"
|
||||
if !filter.SortDesc {
|
||||
direction = "ASC"
|
||||
}
|
||||
return "ORDER BY " + column + " " + direction + " NULLS LAST"
|
||||
}
|
||||
|
||||
func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
return 0, rows.Err()
|
||||
}
|
||||
var total int64
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total, rows.Err()
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return fn(ctx, tx.Client())
|
||||
@@ -294,9 +782,12 @@ func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, us
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT user_id,
|
||||
aff_code,
|
||||
aff_code_custom,
|
||||
aff_rebate_rate_percent,
|
||||
inviter_id,
|
||||
aff_count,
|
||||
aff_quota::double precision,
|
||||
aff_frozen_quota::double precision,
|
||||
aff_history_quota::double precision,
|
||||
created_at,
|
||||
updated_at
|
||||
@@ -315,12 +806,16 @@ WHERE user_id = $1`, userID)
|
||||
|
||||
var out service.AffiliateSummary
|
||||
var inviterID sql.NullInt64
|
||||
var rebateRate sql.NullFloat64
|
||||
if err := rows.Scan(
|
||||
&out.UserID,
|
||||
&out.AffCode,
|
||||
&out.AffCodeCustom,
|
||||
&rebateRate,
|
||||
&inviterID,
|
||||
&out.AffCount,
|
||||
&out.AffQuota,
|
||||
&out.AffFrozenQuota,
|
||||
&out.AffHistoryQuota,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
@@ -330,6 +825,10 @@ WHERE user_id = $1`, userID)
|
||||
if inviterID.Valid {
|
||||
out.InviterID = &inviterID.Int64
|
||||
}
|
||||
if rebateRate.Valid {
|
||||
v := rebateRate.Float64
|
||||
out.AffRebateRatePercent = &v
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
@@ -337,9 +836,12 @@ func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT user_id,
|
||||
aff_code,
|
||||
aff_code_custom,
|
||||
aff_rebate_rate_percent,
|
||||
inviter_id,
|
||||
aff_count,
|
||||
aff_quota::double precision,
|
||||
aff_frozen_quota::double precision,
|
||||
aff_history_quota::double precision,
|
||||
created_at,
|
||||
updated_at
|
||||
@@ -360,12 +862,16 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
|
||||
|
||||
var out service.AffiliateSummary
|
||||
var inviterID sql.NullInt64
|
||||
var rebateRate sql.NullFloat64
|
||||
if err := rows.Scan(
|
||||
&out.UserID,
|
||||
&out.AffCode,
|
||||
&out.AffCodeCustom,
|
||||
&rebateRate,
|
||||
&inviterID,
|
||||
&out.AffCount,
|
||||
&out.AffQuota,
|
||||
&out.AffFrozenQuota,
|
||||
&out.AffHistoryQuota,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
@@ -375,6 +881,10 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
|
||||
if inviterID.Valid {
|
||||
out.InviterID = &inviterID.Int64
|
||||
}
|
||||
if rebateRate.Valid {
|
||||
v := rebateRate.Float64
|
||||
out.AffRebateRatePercent = &v
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
@@ -400,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
type affiliateTransferSnapshot struct {
|
||||
BalanceAfter float64
|
||||
AvailableQuotaAfter float64
|
||||
FrozenQuotaAfter float64
|
||||
HistoryQuotaAfter float64
|
||||
}
|
||||
|
||||
func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT u.balance::double precision,
|
||||
ua.aff_quota::double precision,
|
||||
ua.aff_frozen_quota::double precision,
|
||||
ua.aff_history_quota::double precision
|
||||
FROM users u
|
||||
JOIN user_affiliates ua ON ua.user_id = u.id
|
||||
WHERE u.id = $1
|
||||
LIMIT 1`, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
|
||||
var snapshot affiliateTransferSnapshot
|
||||
if err := rows.Scan(
|
||||
&snapshot.BalanceAfter,
|
||||
&snapshot.AvailableQuotaAfter,
|
||||
&snapshot.FrozenQuotaAfter,
|
||||
&snapshot.HistoryQuotaAfter,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &snapshot, rows.Err()
|
||||
}
|
||||
|
||||
func nullableFloat64Ptr(v sql.NullFloat64) *float64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
return &v.Float64
|
||||
}
|
||||
|
||||
func generateAffiliateCode() (string, error) {
|
||||
buf := make([]byte, affiliateCodeLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
@@ -418,3 +976,236 @@ func isAffiliateUniqueViolation(err error) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateUserAffCode 改写用户的邀请码(自定义专属邀请码)。
|
||||
// 唯一性冲突返回 ErrAffiliateCodeTaken。
|
||||
func (r *affiliateRepository) UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error {
|
||||
if userID <= 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
code := strings.ToUpper(strings.TrimSpace(newCode))
|
||||
if code == "" {
|
||||
return service.ErrAffiliateCodeInvalid
|
||||
}
|
||||
|
||||
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_code = $1,
|
||||
aff_code_custom = true,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, code, userID)
|
||||
if err != nil {
|
||||
if isAffiliateUniqueViolation(err) {
|
||||
return service.ErrAffiliateCodeTaken
|
||||
}
|
||||
return fmt.Errorf("update aff_code: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ResetUserAffCode 把 aff_code 还原为系统随机码,并清除 aff_code_custom 标记。
|
||||
func (r *affiliateRepository) ResetUserAffCode(ctx context.Context, userID int64) (string, error) {
|
||||
if userID <= 0 {
|
||||
return "", service.ErrUserNotFound
|
||||
}
|
||||
var newCode string
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := 0; i < affiliateCodeMaxAttempts; i++ {
|
||||
candidate, codeErr := generateAffiliateCode()
|
||||
if codeErr != nil {
|
||||
return codeErr
|
||||
}
|
||||
res, err := txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_code = $1,
|
||||
aff_code_custom = false,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, candidate, userID)
|
||||
if err != nil {
|
||||
if isAffiliateUniqueViolation(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("reset aff_code: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
newCode = candidate
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reset aff_code: exhausted attempts")
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return newCode, nil
|
||||
}
|
||||
|
||||
// SetUserRebateRate 设置或清除用户专属返利比例。ratePercent==nil 表示清除(沿用全局)。
|
||||
func (r *affiliateRepository) SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
|
||||
if userID <= 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
// nullableArg lets us use a single UPDATE for both "set value" and
|
||||
// "clear" cases — database/sql converts nil interface{} to SQL NULL.
|
||||
res, err := txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_rebate_rate_percent = $1,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, nullableArg(ratePercent), userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set aff_rebate_rate_percent: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// BatchSetUserRebateRate 批量为多个用户设置专属比例(nil 清除)。
|
||||
func (r *affiliateRepository) BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
|
||||
if len(userIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
for _, uid := range userIDs {
|
||||
if uid <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, uid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_rebate_rate_percent = $1,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = ANY($2)`, nullableArg(ratePercent), pq.Array(userIDs))
|
||||
if err != nil {
|
||||
return fmt.Errorf("batch set aff_rebate_rate_percent: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// nullableArg unwraps a *float64 into an interface{} suitable for SQL parameter
|
||||
// binding: nil pointer → SQL NULL, non-nil → the float value.
|
||||
func nullableArg(v *float64) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
func nullableInt64Arg(v *int64) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
|
||||
//
|
||||
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索":
|
||||
// 空 search 时拼接出的 LIKE 模式为 "%%",匹配所有行;非空时按 ILIKE 子串匹配。
|
||||
// 这避免了为两种情况维护两份 SQL 模板。
|
||||
func (r *affiliateRepository) ListUsersWithCustomSettings(ctx context.Context, filter service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
|
||||
page := filter.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := filter.PageSize
|
||||
if pageSize <= 0 || pageSize > 200 {
|
||||
pageSize = 20
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
likePattern := "%" + strings.TrimSpace(filter.Search) + "%"
|
||||
|
||||
const baseFrom = `
|
||||
FROM user_affiliates ua
|
||||
JOIN users u ON u.id = ua.user_id
|
||||
WHERE (ua.aff_code_custom = true OR ua.aff_rebate_rate_percent IS NOT NULL)
|
||||
AND (u.email ILIKE $1 OR u.username ILIKE $1)`
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
total, err := scanInt64(ctx, client, "SELECT COUNT(*)"+baseFrom, likePattern)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("count affiliate admin entries: %w", err)
|
||||
}
|
||||
|
||||
listQuery := `
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.aff_code,
|
||||
ua.aff_code_custom,
|
||||
ua.aff_rebate_rate_percent,
|
||||
ua.aff_count` + baseFrom + `
|
||||
ORDER BY ua.updated_at DESC
|
||||
LIMIT $2 OFFSET $3`
|
||||
|
||||
rows, err := client.QueryContext(ctx, listQuery, likePattern, pageSize, offset)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("list affiliate admin entries: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
entries := make([]service.AffiliateAdminEntry, 0)
|
||||
for rows.Next() {
|
||||
var e service.AffiliateAdminEntry
|
||||
var rebate sql.NullFloat64
|
||||
if err := rows.Scan(&e.UserID, &e.Email, &e.Username, &e.AffCode,
|
||||
&e.AffCodeCustom, &rebate, &e.AffCount); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if rebate.Valid {
|
||||
v := rebate.Float64
|
||||
e.AffRebateRatePercent = &v
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return entries, total, nil
|
||||
}
|
||||
|
||||
// scanInt64 runs a query expected to return a single int64 column (e.g. COUNT).
|
||||
func scanInt64(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
var v int64
|
||||
if err := rows.Scan(&v); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
@@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
|
||||
ledgerCount := querySingleInt(t, txCtx, client,
|
||||
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
|
||||
require.Equal(t, 1, ledgerCount)
|
||||
|
||||
rows, err := client.QueryContext(txCtx, `
|
||||
SELECT amount::double precision,
|
||||
balance_after::double precision,
|
||||
aff_quota_after::double precision,
|
||||
aff_frozen_quota_after::double precision,
|
||||
aff_history_quota_after::double precision
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1 AND action = 'transfer'
|
||||
LIMIT 1`, u.ID)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = rows.Close() }()
|
||||
require.True(t, rows.Next(), "expected transfer ledger")
|
||||
var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64
|
||||
require.NoError(t, rows.Scan(&amount, &balanceAfter, "aAfter, &frozenAfter, &historyAfter))
|
||||
require.InDelta(t, 12.34, amount, 1e-9)
|
||||
require.InDelta(t, 17.84, balanceAfter, 1e-9)
|
||||
require.InDelta(t, 0.0, quotaAfter, 1e-9)
|
||||
require.InDelta(t, 0.0, frozenAfter, 1e-9)
|
||||
require.InDelta(t, 12.34, historyAfter, 1e-9)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
|
||||
@@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, bound, "invitee must bind to inviter")
|
||||
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5)
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, applied, "AccrueQuota must report applied=true")
|
||||
|
||||
@@ -182,3 +202,218 @@ VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
|
||||
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
|
||||
require.InDelta(t, 3.21, persistedBalance, 1e-9)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AdminCustomCode covers the success path of admin
|
||||
// invite-code rewrite + reset within a shared test transaction:
|
||||
// - UpdateUserAffCode replaces aff_code, sets aff_code_custom=true, lookup works
|
||||
// - the old code can no longer be found
|
||||
// - ResetUserAffCode reverts aff_code_custom and assigns a new system-format code
|
||||
//
|
||||
// The conflict path (duplicate code → ErrAffiliateCodeTaken) lives in its own
|
||||
// test because a unique-violation aborts the surrounding Postgres tx, which
|
||||
// would poison subsequent assertions in the same transaction.
|
||||
func TestAffiliateRepository_AdminCustomCode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
u := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-custom-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
original, err := repo.EnsureUserAffiliate(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, original.AffCodeCustom, "system-generated codes start as non-custom")
|
||||
originalCode := original.AffCode
|
||||
|
||||
// Rewrite to a custom code
|
||||
customCode := fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)
|
||||
require.NoError(t, repo.UpdateUserAffCode(txCtx, u.ID, customCode))
|
||||
|
||||
updated, err := repo.EnsureUserAffiliate(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, customCode, updated.AffCode)
|
||||
require.True(t, updated.AffCodeCustom)
|
||||
|
||||
// Lookup by new custom code finds the user
|
||||
byCode, err := repo.GetAffiliateByCode(txCtx, customCode)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, u.ID, byCode.UserID)
|
||||
|
||||
// Old system code should no longer match
|
||||
_, err = repo.GetAffiliateByCode(txCtx, originalCode)
|
||||
require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
|
||||
|
||||
// Reset back to a fresh system code, clears custom flag
|
||||
newSysCode, err := repo.ResetUserAffCode(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, customCode, newSysCode)
|
||||
|
||||
reset, err := repo.EnsureUserAffiliate(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newSysCode, reset.AffCode)
|
||||
require.False(t, reset.AffCodeCustom)
|
||||
|
||||
// The old custom code is now free again
|
||||
_, err = repo.GetAffiliateByCode(txCtx, customCode)
|
||||
require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AdminCustomCode_Conflict isolates the unique-violation
|
||||
// path. PostgreSQL aborts the enclosing tx when a unique constraint fires, so
|
||||
// this test must be the only assertion and run in its own tx — production
|
||||
// callers each have their own outer tx, so this matches real behavior.
|
||||
func TestAffiliateRepository_AdminCustomCode_Conflict(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
taker := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-conflict-taker-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
requester := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-conflict-req-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
|
||||
takenCode := fmt.Sprintf("HOT%09d", time.Now().UnixNano()%1_000_000_000)
|
||||
require.NoError(t, repo.UpdateUserAffCode(txCtx, taker.ID, takenCode))
|
||||
|
||||
// Now requester tries to grab the same code → conflict.
|
||||
err := repo.UpdateUserAffCode(txCtx, requester.ID, takenCode)
|
||||
require.ErrorIs(t, err, service.ErrAffiliateCodeTaken)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AdminRebateRate covers per-user exclusive rate
|
||||
// set/clear and the Batch variant including NULL semantics.
|
||||
func TestAffiliateRepository_AdminRebateRate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
u1 := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-rate-%d-a@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
u2 := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-rate-%d-b@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
// Set exclusive rate for u1
|
||||
rate := 42.5
|
||||
require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, &rate))
|
||||
|
||||
got, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.AffRebateRatePercent)
|
||||
require.InDelta(t, 42.5, *got.AffRebateRatePercent, 1e-9)
|
||||
|
||||
// Clear exclusive rate
|
||||
require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, nil))
|
||||
cleared, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, cleared.AffRebateRatePercent)
|
||||
|
||||
// Batch set both users
|
||||
batchRate := 15.0
|
||||
require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, &batchRate))
|
||||
|
||||
for _, uid := range []int64{u1.ID, u2.ID} {
|
||||
v, err := repo.EnsureUserAffiliate(txCtx, uid)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, v.AffRebateRatePercent)
|
||||
require.InDelta(t, 15.0, *v.AffRebateRatePercent, 1e-9)
|
||||
}
|
||||
|
||||
// Batch clear
|
||||
require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, nil))
|
||||
for _, uid := range []int64{u1.ID, u2.ID} {
|
||||
v, err := repo.EnsureUserAffiliate(txCtx, uid)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, v.AffRebateRatePercent)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_ListUsersWithCustomSettings verifies the admin list
|
||||
// only includes users with at least one override applied.
|
||||
func TestAffiliateRepository_ListUsersWithCustomSettings(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
// User without any custom config — should NOT appear in the list.
|
||||
plainEmail := fmt.Sprintf("affiliate-plain-%d@example.com", time.Now().UnixNano())
|
||||
uPlain := mustCreateUser(t, client, &service.User{
|
||||
Email: plainEmail, PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
_, err := repo.EnsureUserAffiliate(txCtx, uPlain.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User with a custom code — should appear.
|
||||
uCode := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-codeonly-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
require.NoError(t, repo.UpdateUserAffCode(txCtx, uCode.ID, fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)))
|
||||
|
||||
// User with only an exclusive rate — should appear.
|
||||
uRate := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-rateonly-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser, Status: service.StatusActive,
|
||||
})
|
||||
r := 33.3
|
||||
require.NoError(t, repo.SetUserRebateRate(txCtx, uRate.ID, &r))
|
||||
|
||||
entries, total, err := repo.ListUsersWithCustomSettings(txCtx, service.AffiliateAdminFilter{
|
||||
Page: 1, PageSize: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build a quick lookup to assert per-user attributes (other tests may have
|
||||
// inserted custom rows in the same DB; we only care about our 3).
|
||||
byUserID := make(map[int64]service.AffiliateAdminEntry, len(entries))
|
||||
for _, e := range entries {
|
||||
byUserID[e.UserID] = e
|
||||
}
|
||||
|
||||
require.NotContains(t, byUserID, uPlain.ID, "users without overrides must not appear")
|
||||
|
||||
codeEntry, ok := byUserID[uCode.ID]
|
||||
require.True(t, ok, "custom-code user missing from list")
|
||||
require.True(t, codeEntry.AffCodeCustom)
|
||||
require.Nil(t, codeEntry.AffRebateRatePercent)
|
||||
|
||||
rateEntry, ok := byUserID[uRate.ID]
|
||||
require.True(t, ok, "custom-rate user missing from list")
|
||||
require.False(t, rateEntry.AffCodeCustom)
|
||||
require.NotNil(t, rateEntry.AffRebateRatePercent)
|
||||
require.InDelta(t, 33.3, *rateEntry.AffRebateRatePercent, 1e-9)
|
||||
|
||||
require.GreaterOrEqual(t, total, int64(2), "total must include at least our 2 custom rows")
|
||||
}
|
||||
|
||||
28
backend/internal/repository/affiliate_repo_test.go
Normal file
28
backend/internal/repository/affiliate_repo_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) {
|
||||
query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ")
|
||||
|
||||
require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)")
|
||||
require.Contains(t, query, "frozen_until <= NOW()")
|
||||
}
|
||||
|
||||
func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) {
|
||||
source, err := os.ReadFile("affiliate_repo.go")
|
||||
require.NoError(t, err)
|
||||
content := string(source)
|
||||
|
||||
require.Contains(t, content, "JOIN payment_orders po ON po.id = ual.source_order_id")
|
||||
require.Contains(t, content, "ual.amount::double precision")
|
||||
require.Contains(t, content, "ual.balance_after::double precision")
|
||||
require.NotContains(t, content, "parseAffiliateRebateAmount")
|
||||
require.NotContains(t, content, `"current_balance": "u.balance"`)
|
||||
}
|
||||
@@ -24,6 +24,49 @@ const (
|
||||
|
||||
defaultSchedulerSnapshotMGetChunkSize = 128
|
||||
defaultSchedulerSnapshotWriteChunkSize = 256
|
||||
|
||||
// snapshotGraceTTLSeconds 旧快照过期的宽限期(秒)。
|
||||
// 替代立即 DEL,让正在读取旧版本的 reader 有足够时间完成 ZRANGE。
|
||||
snapshotGraceTTLSeconds = 60
|
||||
)
|
||||
|
||||
var (
|
||||
// activateSnapshotScript 原子 CAS 切换快照版本。
|
||||
// 仅当新版本号 >= 当前激活版本时才切换,防止并发写入导致版本回滚。
|
||||
// 旧快照使用 EXPIRE 设置宽限期而非立即 DEL,避免与 reader 竞态。
|
||||
//
|
||||
// KEYS[1] = activeKey (sched:active:{bucket})
|
||||
// KEYS[2] = readyKey (sched:ready:{bucket})
|
||||
// KEYS[3] = bucketSetKey (sched:buckets)
|
||||
// KEYS[4] = snapshotKey (新写入的快照 key)
|
||||
// ARGV[1] = 新版本号字符串
|
||||
// ARGV[2] = bucket 字符串 (用于 SADD)
|
||||
// ARGV[3] = 快照 key 前缀 (用于构造旧快照 key)
|
||||
// ARGV[4] = 宽限期 TTL 秒数
|
||||
//
|
||||
// 返回 1 = 已激活, 0 = 版本过旧未激活
|
||||
activateSnapshotScript = redis.NewScript(`
|
||||
local currentActive = redis.call('GET', KEYS[1])
|
||||
local newVersion = tonumber(ARGV[1])
|
||||
|
||||
if currentActive ~= false then
|
||||
local curVersion = tonumber(currentActive)
|
||||
if curVersion and newVersion < curVersion then
|
||||
redis.call('DEL', KEYS[4])
|
||||
return 0
|
||||
end
|
||||
end
|
||||
|
||||
redis.call('SET', KEYS[1], ARGV[1])
|
||||
redis.call('SET', KEYS[2], '1')
|
||||
redis.call('SADD', KEYS[3], ARGV[2])
|
||||
|
||||
if currentActive ~= false and currentActive ~= ARGV[1] then
|
||||
redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type schedulerCache struct {
|
||||
@@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
}
|
||||
|
||||
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||
oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
|
||||
|
||||
// Phase 1: 分配新版本号并写入快照数据。
|
||||
// INCR 保证每个调用方获得唯一递增版本号。
|
||||
// 写入的 snapshotKey 是新的版本化 key,reader 尚不知晓,因此无竞态。
|
||||
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
|
||||
version, err := c.rdb.Incr(ctx, versionKey).Result()
|
||||
if err != nil {
|
||||
@@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
if len(accounts) > 0 {
|
||||
// 使用序号作为 score,保持数据库返回的排序语义。
|
||||
members := make([]redis.Z, 0, len(accounts))
|
||||
@@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
Member: strconv.FormatInt(account.ID, 10),
|
||||
})
|
||||
}
|
||||
pipe := c.rdb.Pipeline()
|
||||
for start := 0; start < len(members); start += c.writeChunkSize {
|
||||
end := start + c.writeChunkSize
|
||||
if end > len(members) {
|
||||
@@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
}
|
||||
pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
|
||||
}
|
||||
} else {
|
||||
pipe.Del(ctx, snapshotKey)
|
||||
}
|
||||
pipe.Set(ctx, activeKey, versionStr, 0)
|
||||
pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
|
||||
pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return err
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if oldActive != "" && oldActive != versionStr {
|
||||
_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
|
||||
// Phase 2: 原子 CAS 激活版本。
|
||||
// Lua 脚本保证:仅当新版本 >= 当前激活版本时才切换 active 指针,
|
||||
// 防止并发写入导致版本回滚。
|
||||
// 旧快照使用 EXPIRE 宽限期而非立即 DEL,避免 reader 竞态。
|
||||
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||
readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
|
||||
snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode)
|
||||
|
||||
keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
|
||||
args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}
|
||||
|
||||
_, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched
|
||||
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
|
||||
key := schedulerBucketKey(schedulerLockPrefix, bucket)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
|
||||
if err != nil {
|
||||
@@ -394,11 +449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
|
||||
SessionWindowStart: account.SessionWindowStart,
|
||||
SessionWindowEnd: account.SessionWindowEnd,
|
||||
SessionWindowStatus: account.SessionWindowStatus,
|
||||
AccountGroups: filterSchedulerAccountGroups(account.AccountGroups),
|
||||
GroupIDs: filterSchedulerGroupIDs(account.GroupIDs, account.AccountGroups),
|
||||
Credentials: filterSchedulerCredentials(account.Credentials),
|
||||
Extra: filterSchedulerExtra(account.Extra),
|
||||
}
|
||||
}
|
||||
|
||||
func filterSchedulerAccountGroups(accountGroups []service.AccountGroup) []service.AccountGroup {
|
||||
if len(accountGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
filtered := make([]service.AccountGroup, 0, len(accountGroups))
|
||||
for _, ag := range accountGroups {
|
||||
if ag.GroupID <= 0 {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, service.AccountGroup{
|
||||
AccountID: ag.AccountID,
|
||||
GroupID: ag.GroupID,
|
||||
Priority: ag.Priority,
|
||||
CreatedAt: ag.CreatedAt,
|
||||
})
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterSchedulerGroupIDs(groupIDs []int64, accountGroups []service.AccountGroup) []int64 {
|
||||
if len(groupIDs) == 0 && len(accountGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[int64]struct{}, len(groupIDs)+len(accountGroups))
|
||||
filtered := make([]int64, 0, len(groupIDs)+len(accountGroups))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
filtered = append(filtered, id)
|
||||
}
|
||||
for _, ag := range accountGroups {
|
||||
if ag.GroupID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[ag.GroupID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[ag.GroupID] = struct{}{}
|
||||
filtered = append(filtered, ag.GroupID)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
|
||||
if len(credentials) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
|
||||
SessionWindowStart: &now,
|
||||
SessionWindowEnd: &windowEnd,
|
||||
SessionWindowStatus: "active",
|
||||
GroupIDs: []int64{bucket.GroupID},
|
||||
AccountGroups: []service.AccountGroup{
|
||||
{
|
||||
AccountID: 101,
|
||||
GroupID: bucket.GroupID,
|
||||
Priority: 5,
|
||||
Group: &service.Group{ID: bucket.GroupID, Name: "gemini-group"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
|
||||
@@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
|
||||
require.Equal(t, 4, got.GetMaxSessions())
|
||||
require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
|
||||
require.Nil(t, got.Extra["unused_large_field"])
|
||||
require.Equal(t, []int64{bucket.GroupID}, got.GroupIDs)
|
||||
require.Len(t, got.AccountGroups, 1)
|
||||
require.Equal(t, account.ID, got.AccountGroups[0].AccountID)
|
||||
require.Equal(t, bucket.GroupID, got.AccountGroups[0].GroupID)
|
||||
require.Nil(t, got.AccountGroups[0].Group)
|
||||
|
||||
full, err := cache.GetAccount(ctx, account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, full)
|
||||
require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
|
||||
require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
|
||||
require.Len(t, full.AccountGroups, 1)
|
||||
require.NotNil(t, full.AccountGroups[0].Group)
|
||||
}
|
||||
|
||||
@@ -31,3 +31,43 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
|
||||
require.Equal(t, true, got.Extra["mixed_scheduling"])
|
||||
require.Nil(t, got.Extra["unused_large_field"])
|
||||
}
|
||||
|
||||
func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) {
|
||||
account := service.Account{
|
||||
ID: 42,
|
||||
Platform: service.PlatformAnthropic,
|
||||
GroupIDs: []int64{7, 9, 7, 0},
|
||||
AccountGroups: []service.AccountGroup{
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 7,
|
||||
Priority: 2,
|
||||
Account: &service.Account{ID: 42, Name: "drop-from-metadata"},
|
||||
Group: &service.Group{ID: 7, Name: "drop-from-metadata"},
|
||||
},
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 11,
|
||||
Priority: 3,
|
||||
Group: &service.Group{ID: 11, Name: "drop-from-metadata"},
|
||||
},
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 0,
|
||||
Priority: 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := buildSchedulerMetadataAccount(account)
|
||||
|
||||
require.Equal(t, []int64{7, 9, 11}, got.GroupIDs)
|
||||
require.Len(t, got.AccountGroups, 2)
|
||||
require.Equal(t, int64(42), got.AccountGroups[0].AccountID)
|
||||
require.Equal(t, int64(7), got.AccountGroups[0].GroupID)
|
||||
require.Equal(t, 2, got.AccountGroups[0].Priority)
|
||||
require.Nil(t, got.AccountGroups[0].Account)
|
||||
require.Nil(t, got.AccountGroups[0].Group)
|
||||
require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
|
||||
require.Nil(t, got.Groups)
|
||||
}
|
||||
|
||||
@@ -716,6 +716,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25,
|
||||
"affiliate_rebate_rate": 20,
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
"affiliate_rebate_per_invitee_cap": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
@@ -737,6 +740,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"enable_cch_signing": false,
|
||||
"enable_anthropic_cache_ttl_1h_injection": false,
|
||||
"enable_fingerprint_unification": true,
|
||||
"enable_metadata_passthrough": false,
|
||||
"web_search_emulation_enabled": false,
|
||||
@@ -745,6 +749,16 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_visible_method_alipay_enabled": true,
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": true,
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": [
|
||||
{
|
||||
"service_tier": "priority",
|
||||
"action": "filter",
|
||||
"scope": "all",
|
||||
"fallback_action": "pass"
|
||||
}
|
||||
]
|
||||
},
|
||||
"custom_menu_items": [],
|
||||
"custom_endpoints": [],
|
||||
"payment_enabled": false,
|
||||
@@ -775,6 +789,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"channel_monitor_enabled": true,
|
||||
"channel_monitor_default_interval_seconds": 60,
|
||||
"available_channels_enabled": false,
|
||||
"affiliate_enabled": false,
|
||||
"wechat_connect_enabled": false,
|
||||
"wechat_connect_app_id": "",
|
||||
"wechat_connect_app_secret_configured": false,
|
||||
@@ -897,6 +912,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"default_concurrency": 0,
|
||||
"default_balance": 0,
|
||||
"affiliate_rebate_rate": 20,
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
"affiliate_rebate_per_invitee_cap": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
@@ -917,12 +935,23 @@ func TestAPIContracts(t *testing.T) {
|
||||
"enable_fingerprint_unification": true,
|
||||
"enable_metadata_passthrough": false,
|
||||
"enable_cch_signing": false,
|
||||
"enable_anthropic_cache_ttl_1h_injection": false,
|
||||
"web_search_emulation_enabled": false,
|
||||
"payment_visible_method_alipay_source": "",
|
||||
"payment_visible_method_wxpay_source": "",
|
||||
"payment_visible_method_alipay_enabled": false,
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": false,
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": [
|
||||
{
|
||||
"service_tier": "priority",
|
||||
"action": "filter",
|
||||
"scope": "all",
|
||||
"fallback_action": "pass"
|
||||
}
|
||||
]
|
||||
},
|
||||
"payment_enabled": false,
|
||||
"payment_min_amount": 0,
|
||||
"payment_max_amount": 0,
|
||||
@@ -951,6 +980,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"channel_monitor_enabled": true,
|
||||
"channel_monitor_default_interval_seconds": 60,
|
||||
"available_channels_enabled": false,
|
||||
"affiliate_enabled": false,
|
||||
"wechat_connect_enabled": true,
|
||||
"wechat_connect_app_id": "wx-open-config",
|
||||
"wechat_connect_app_secret_configured": true,
|
||||
|
||||
@@ -91,6 +91,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// 渠道监控
|
||||
registerChannelMonitorRoutes(admin, h)
|
||||
|
||||
// 邀请返利(专属用户管理)
|
||||
registerAffiliateRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -594,3 +597,23 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
|
||||
}
|
||||
}
|
||||
|
||||
// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
|
||||
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
affiliates := admin.Group("/affiliates")
|
||||
{
|
||||
affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords)
|
||||
affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords)
|
||||
affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords)
|
||||
|
||||
users := affiliates.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.Affiliate.ListUsers)
|
||||
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
|
||||
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
|
||||
users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview)
|
||||
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
|
||||
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -64,6 +65,7 @@ func isOpenAIImageModel(model string) bool {
|
||||
type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
@@ -74,6 +76,7 @@ type AccountTestService struct {
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
@@ -82,6 +85,7 @@ func NewAccountTestService(
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
@@ -210,6 +214,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
if account.IsBedrock() {
|
||||
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
|
||||
}
|
||||
if account.Type == AccountTypeServiceAccount {
|
||||
return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID)
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
var authToken string
|
||||
@@ -313,6 +320,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||
if mappedModel, matched := account.ResolveMappedModel(testModelID); matched {
|
||||
testModelID = mappedModel
|
||||
} else {
|
||||
testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID))
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
payload, err := createTestPayload(testModelID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error()))
|
||||
}
|
||||
|
||||
if s.claudeTokenProvider == nil {
|
||||
return s.sendErrorAndEnd(c, "Claude token provider not configured")
|
||||
}
|
||||
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error()))
|
||||
}
|
||||
|
||||
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error()))
|
||||
}
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errMsg)
|
||||
}
|
||||
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
||||
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||
region := bedrockRuntimeRegion(account)
|
||||
@@ -480,7 +555,16 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
|
||||
// 账号已被探测为不支持 Responses(如 DeepSeek/Kimi 等)时,丢出明确提示。
|
||||
// 账号本身可用(网关会走 CC 直转),仅测试入口需要补齐 CC SSE 处理逻辑。
|
||||
// TODO:实现 CC 格式的账号测试路径(需专门的 CC SSE handler)。
|
||||
if !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return s.sendErrorAndEnd(c,
|
||||
"账号已被探测为不支持 OpenAI Responses API(如 DeepSeek/Kimi 等三方兼容上游),"+
|
||||
"账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。",
|
||||
)
|
||||
}
|
||||
apiURL = buildOpenAIResponsesURL(normalizedBaseURL)
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -711,8 +795,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
testModelID = geminicli.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
// For static upstream credentials with model mapping, map the model
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
@@ -740,6 +824,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeServiceAccount:
|
||||
req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload)
|
||||
default:
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -893,6 +979,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
|
||||
return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
|
||||
}
|
||||
|
||||
func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
if s.geminiTokenProvider == nil {
|
||||
return nil, fmt.Errorf("gemini token provider not configured")
|
||||
}
|
||||
accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service account access token: %w", err)
|
||||
}
|
||||
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
|
||||
func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
|
||||
var inner map[string]any
|
||||
@@ -1145,13 +1252,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
|
||||
// processOpenAIStream processes the SSE stream from OpenAI Responses API
|
||||
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
seenCompleted := false
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
if seenCompleted {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
@@ -1163,8 +1274,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
|
||||
|
||||
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
if seenCompleted {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
@@ -1180,9 +1294,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
|
||||
if delta, ok := data["delta"].(string); ok && delta != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
|
||||
}
|
||||
case "response.completed":
|
||||
case "response.completed", "response.done":
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
case "response.failed":
|
||||
errorMsg := "OpenAI response failed"
|
||||
if responseData, ok := data["response"].(map[string]any); ok {
|
||||
if errData, ok := responseData["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok && msg != "" {
|
||||
errorMsg = msg
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
@@ -1210,7 +1334,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
|
||||
apiURL := buildOpenAIImagesURL(normalizedBaseURL, openAIImagesGenerationsEndpoint)
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -48,3 +49,42 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes
|
||||
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
|
||||
require.Contains(t, rec.Body.String(), "\"success\":true")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 54,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
|
||||
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
|
||||
require.Contains(t, rec.Body.String(), "\"success\":true")
|
||||
}
|
||||
|
||||
@@ -125,6 +125,31 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
|
||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"}
|
||||
|
||||
`))
|
||||
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 90,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, recorder.Body.String(), "response.completed")
|
||||
require.NotContains(t, recorder.Body.String(), `"success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
86
backend/internal/service/admin_balance_history_test.go
Normal file
86
backend/internal/service/admin_balance_history_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
|
||||
older := now.Add(-2 * time.Hour)
|
||||
newer := now.Add(time.Hour)
|
||||
|
||||
usedBy := int64(10)
|
||||
redeemCodes := []RedeemCode{
|
||||
{
|
||||
ID: 1,
|
||||
Type: RedeemTypeBalance,
|
||||
Value: 8,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &now,
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Type: RedeemTypeConcurrency,
|
||||
Value: 1,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &older,
|
||||
CreatedAt: older,
|
||||
},
|
||||
}
|
||||
affiliateCodes := []RedeemCode{
|
||||
{
|
||||
ID: -20,
|
||||
Type: RedeemTypeAffiliateBalance,
|
||||
Value: 3.5,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &newer,
|
||||
CreatedAt: newer,
|
||||
},
|
||||
}
|
||||
|
||||
got := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 2,
|
||||
})
|
||||
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, RedeemTypeAffiliateBalance, got[0].Type)
|
||||
require.Equal(t, RedeemTypeBalance, got[1].Type)
|
||||
}
|
||||
|
||||
func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
|
||||
usedBy := int64(10)
|
||||
at := func(hours int) *time.Time {
|
||||
v := base.Add(time.Duration(hours) * time.Hour)
|
||||
return &v
|
||||
}
|
||||
|
||||
got := mergeBalanceHistoryCodes(
|
||||
[]RedeemCode{
|
||||
{ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)},
|
||||
{ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)},
|
||||
},
|
||||
[]RedeemCode{
|
||||
{ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)},
|
||||
{ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)},
|
||||
},
|
||||
pagination.PaginationParams{Page: 2, PageSize: 2},
|
||||
)
|
||||
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, RedeemTypeConcurrency, got[0].Type)
|
||||
require.Equal(t, int64(-4), got[1].ID)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -58,6 +60,7 @@ type AdminService interface {
|
||||
|
||||
// API Key management (admin)
|
||||
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
|
||||
AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error)
|
||||
|
||||
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
|
||||
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
|
||||
@@ -291,6 +294,7 @@ type UpdateAccountInput struct {
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
type BulkUpdateAccountsInput struct {
|
||||
AccountIDs []int64
|
||||
Filters *BulkUpdateAccountFilters
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
@@ -307,6 +311,15 @@ type BulkUpdateAccountsInput struct {
|
||||
SkipMixedChannelCheck bool
|
||||
}
|
||||
|
||||
type BulkUpdateAccountFilters struct {
|
||||
Platform string
|
||||
Type string
|
||||
Status string
|
||||
Group string
|
||||
Search string
|
||||
PrivacyMode string
|
||||
}
|
||||
|
||||
// BulkUpdateAccountResult captures the result for a single account update.
|
||||
type BulkUpdateAccountResult struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
@@ -961,16 +974,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
if codeType == RedeemTypeAffiliateBalance {
|
||||
codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, total, totalRecharged, nil
|
||||
}
|
||||
|
||||
if codeType == "" {
|
||||
return s.getAllUserBalanceHistory(ctx, userID, params)
|
||||
}
|
||||
|
||||
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
total := result.Total
|
||||
// Aggregate total recharged amount (only once, regardless of type filter)
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, result.Total, totalRecharged, nil
|
||||
return codes, total, totalRecharged, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) {
|
||||
needed := params.Offset() + params.Limit()
|
||||
if needed < params.Limit() {
|
||||
needed = params.Limit()
|
||||
}
|
||||
|
||||
redeemCodes, redeemTotal, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, needed)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
affiliateCodes, affiliateTotal, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, needed)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
codes := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, params)
|
||||
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, redeemTotal + affiliateTotal, totalRecharged, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
|
||||
if needed <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
out []RedeemCode
|
||||
total int64
|
||||
)
|
||||
for page := 1; len(out) < needed; page++ {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: 1000}
|
||||
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, "")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result != nil {
|
||||
total = result.Total
|
||||
}
|
||||
out = append(out, codes...)
|
||||
if len(codes) < params.Limit() || int64(len(out)) >= total {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(out) > needed {
|
||||
out = out[:needed]
|
||||
}
|
||||
return out, total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
|
||||
if needed <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
out []RedeemCode
|
||||
total int64
|
||||
)
|
||||
for page := 1; len(out) < needed; page++ {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: 1000}
|
||||
codes, currentTotal, err := s.listAffiliateBalanceHistory(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = currentTotal
|
||||
out = append(out, codes...)
|
||||
if len(codes) < params.Limit() || int64(len(out)) >= total {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(out) > needed {
|
||||
out = out[:needed]
|
||||
}
|
||||
return out, total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
rows, err := s.entClient.QueryContext(ctx, `
|
||||
SELECT id,
|
||||
amount::double precision,
|
||||
created_at
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1
|
||||
AND action = 'transfer'
|
||||
ORDER BY created_at DESC, id DESC
|
||||
OFFSET $2
|
||||
LIMIT $3`, userID, params.Offset(), params.Limit())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
codes := make([]RedeemCode, 0, params.Limit())
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var amount float64
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&id, &amount, &createdAt); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
usedBy := userID
|
||||
usedAt := createdAt
|
||||
codes = append(codes, RedeemCode{
|
||||
ID: -id,
|
||||
Code: fmt.Sprintf("AFF-%d", id),
|
||||
Type: RedeemTypeAffiliateBalance,
|
||||
Value: amount,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &usedAt,
|
||||
CreatedAt: createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
total, err := countAffiliateBalanceHistory(ctx, s.entClient, userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return codes, total, nil
|
||||
}
|
||||
|
||||
func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1
|
||||
AND action = 'transfer'`, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var total sql.NullInt64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !total.Valid {
|
||||
return 0, nil
|
||||
}
|
||||
return total.Int64, nil
|
||||
}
|
||||
|
||||
func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode {
|
||||
combined := append(append([]RedeemCode{}, redeemCodes...), affiliateCodes...)
|
||||
sort.SliceStable(combined, func(i, j int) bool {
|
||||
return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j]))
|
||||
})
|
||||
offset := params.Offset()
|
||||
if offset >= len(combined) {
|
||||
return []RedeemCode{}
|
||||
}
|
||||
end := offset + params.Limit()
|
||||
if end > len(combined) {
|
||||
end = len(combined)
|
||||
}
|
||||
return combined[offset:end]
|
||||
}
|
||||
|
||||
func redeemCodeHistoryTime(code RedeemCode) time.Time {
|
||||
if code.UsedAt != nil {
|
||||
return *code.UsedAt
|
||||
}
|
||||
return code.CreatedAt
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
|
||||
@@ -1961,6 +2171,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows.
|
||||
func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apiKey.Usage5h = 0
|
||||
apiKey.Usage1d = 0
|
||||
apiKey.Usage7d = 0
|
||||
apiKey.Window5hStart = nil
|
||||
apiKey.Window1dStart = nil
|
||||
apiKey.Window7dStart = nil
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("reset api key rate limit usage: %w", err)
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
}
|
||||
if s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// ReplaceUserGroup 替换用户的专属分组
|
||||
func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) {
|
||||
if oldGroupID == newGroupID {
|
||||
@@ -2286,6 +2520,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
// BulkUpdateAccounts updates multiple accounts in one request.
|
||||
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||
if len(input.AccountIDs) == 0 && input.Filters != nil {
|
||||
accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
input.AccountIDs = accountIDs
|
||||
}
|
||||
|
||||
result := &BulkUpdateAccountsResult{
|
||||
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||
@@ -2401,6 +2643,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) {
|
||||
if filters == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
groupID := int64(0)
|
||||
switch strings.TrimSpace(filters.Group) {
|
||||
case "":
|
||||
case "ungrouped":
|
||||
groupID = AccountListGroupUngrouped
|
||||
default:
|
||||
parsedGroupID, err := strconv.ParseInt(strings.TrimSpace(filters.Group), 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid group filter: %w", err)
|
||||
}
|
||||
groupID = parsedGroupID
|
||||
}
|
||||
|
||||
const pageSize = 500
|
||||
page := 1
|
||||
accountIDs := make([]int64, 0, pageSize)
|
||||
|
||||
for {
|
||||
accounts, total, err := s.ListAccounts(
|
||||
ctx,
|
||||
page,
|
||||
pageSize,
|
||||
filters.Platform,
|
||||
filters.Type,
|
||||
filters.Status,
|
||||
filters.Search,
|
||||
groupID,
|
||||
filters.PrivacyMode,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, account := range accounts {
|
||||
accountIDs = append(accountIDs, account.ID)
|
||||
}
|
||||
if int64(len(accountIDs)) >= total || len(accounts) == 0 {
|
||||
return accountIDs, nil
|
||||
}
|
||||
page++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
|
||||
@@ -5,8 +5,10 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
|
||||
getByIDCalled []int64
|
||||
listByGroupData map[int64][]Account
|
||||
listByGroupErr map[int64]error
|
||||
listData []Account
|
||||
listResult *pagination.PaginationResult
|
||||
listErr error
|
||||
listCalled bool
|
||||
lastListParams pagination.PaginationParams
|
||||
lastListFilters struct {
|
||||
platform string
|
||||
accountType string
|
||||
status string
|
||||
search string
|
||||
groupID int64
|
||||
privacyMode string
|
||||
}
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
@@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listCalled = true
|
||||
s.lastListParams = params
|
||||
s.lastListFilters.platform = platform
|
||||
s.lastListFilters.accountType = accountType
|
||||
s.lastListFilters.status = status
|
||||
s.lastListFilters.search = search
|
||||
s.lastListFilters.groupID = groupID
|
||||
s.lastListFilters.privacyMode = privacyMode
|
||||
if s.listErr != nil {
|
||||
return nil, nil, s.listErr
|
||||
}
|
||||
if s.listResult != nil {
|
||||
return s.listData, s.listResult, nil
|
||||
}
|
||||
return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
@@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
|
||||
// No BindGroups should have been called since the check runs before any write.
|
||||
require.Empty(t, repo.bindGroupsCalls)
|
||||
}
|
||||
|
||||
func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
listData: []Account{
|
||||
{ID: 7},
|
||||
{ID: 11},
|
||||
},
|
||||
listResult: &pagination.PaginationResult{Total: 2},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
schedulable := true
|
||||
input := &BulkUpdateAccountsInput{
|
||||
Schedulable: &schedulable,
|
||||
}
|
||||
|
||||
filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters")
|
||||
require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update")
|
||||
require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field")
|
||||
|
||||
filtersValue := reflect.New(filtersField.Type().Elem())
|
||||
filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI)
|
||||
filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth)
|
||||
filtersValue.Elem().FieldByName("Status").SetString(StatusActive)
|
||||
filtersValue.Elem().FieldByName("Group").SetString("12")
|
||||
filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked)
|
||||
filtersValue.Elem().FieldByName("Search").SetString("bulk-target")
|
||||
filtersField.Set(filtersValue)
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters")
|
||||
require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform)
|
||||
require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType)
|
||||
require.Equal(t, StatusActive, repo.lastListFilters.status)
|
||||
require.Equal(t, "bulk-target", repo.lastListFilters.search)
|
||||
require.Equal(t, int64(12), repo.lastListFilters.groupID)
|
||||
require.Equal(t, PrivacyModeCFBlocked, repo.lastListFilters.privacyMode)
|
||||
require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs)
|
||||
require.Equal(t, 2, result.Success)
|
||||
require.Equal(t, 0, result.Failed)
|
||||
require.Equal(t, []int64{7, 11}, result.SuccessIDs)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -15,28 +14,39 @@ import (
|
||||
var (
|
||||
ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
|
||||
ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
|
||||
ErrAffiliateCodeTaken = infraerrors.Conflict("AFFILIATE_CODE_TAKEN", "affiliate code already in use")
|
||||
ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
|
||||
ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
|
||||
)
|
||||
|
||||
const (
|
||||
affiliateInviteesLimit = 100
|
||||
// affiliateCodeFormatLength must stay in sync with repository.affiliateCodeLength.
|
||||
affiliateCodeFormatLength = 12
|
||||
// AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated
|
||||
// 12-char codes and admin-customized codes (e.g. "VIP2026").
|
||||
AffiliateCodeMinLength = 4
|
||||
AffiliateCodeMaxLength = 32
|
||||
)
|
||||
|
||||
// affiliateCodeValidChar is a 256-entry lookup table mirroring the charset used
|
||||
// by the repository's generateAffiliateCode (A-Z minus I/O, digits 2-9).
|
||||
// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash.
|
||||
// All input passes through strings.ToUpper before validation, so lowercase from
|
||||
// users is normalized — admins may supply mixed case in their UI.
|
||||
var affiliateCodeValidChar = func() [256]bool {
|
||||
var tbl [256]bool
|
||||
for _, c := range []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") {
|
||||
for c := byte('A'); c <= 'Z'; c++ {
|
||||
tbl[c] = true
|
||||
}
|
||||
for c := byte('0'); c <= '9'; c++ {
|
||||
tbl[c] = true
|
||||
}
|
||||
tbl['_'] = true
|
||||
tbl['-'] = true
|
||||
return tbl
|
||||
}()
|
||||
|
||||
// isValidAffiliateCodeFormat validates code format for both binding (user input)
|
||||
// and admin updates. Caller is expected to upper-case the input first.
|
||||
func isValidAffiliateCodeFormat(code string) bool {
|
||||
if len(code) != affiliateCodeFormatLength {
|
||||
if len(code) < AffiliateCodeMinLength || len(code) > AffiliateCodeMaxLength {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(code); i++ {
|
||||
@@ -48,58 +58,176 @@ func isValidAffiliateCodeFormat(code string) bool {
|
||||
}
|
||||
|
||||
type AffiliateSummary struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
AffCode string `json:"aff_code"`
|
||||
InviterID *int64 `json:"inviter_id,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
UserID int64 `json:"user_id"`
|
||||
AffCode string `json:"aff_code"`
|
||||
AffCodeCustom bool `json:"aff_code_custom"`
|
||||
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
|
||||
InviterID *int64 `json:"inviter_id,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffFrozenQuota float64 `json:"aff_frozen_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type AffiliateInvitee struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt *time.Time `json:"created_at,omitempty"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt *time.Time `json:"created_at,omitempty"`
|
||||
TotalRebate float64 `json:"total_rebate"`
|
||||
}
|
||||
|
||||
type AffiliateDetail struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
AffCode string `json:"aff_code"`
|
||||
InviterID *int64 `json:"inviter_id,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
Invitees []AffiliateInvitee `json:"invitees"`
|
||||
UserID int64 `json:"user_id"`
|
||||
AffCode string `json:"aff_code"`
|
||||
InviterID *int64 `json:"inviter_id,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffFrozenQuota float64 `json:"aff_frozen_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
|
||||
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
|
||||
// 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
|
||||
EffectiveRebateRatePercent float64 `json:"effective_rebate_rate_percent"`
|
||||
Invitees []AffiliateInvitee `json:"invitees"`
|
||||
}
|
||||
|
||||
type AffiliateRepository interface {
|
||||
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
|
||||
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
|
||||
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
|
||||
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
|
||||
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
|
||||
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
|
||||
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
|
||||
|
||||
// 管理端:用户级专属配置
|
||||
UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error
|
||||
ResetUserAffCode(ctx context.Context, userID int64) (string, error)
|
||||
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
|
||||
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
|
||||
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
|
||||
ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error)
|
||||
ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error)
|
||||
ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error)
|
||||
GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error)
|
||||
}
|
||||
|
||||
// AffiliateAdminFilter 列表筛选条件
|
||||
type AffiliateAdminFilter struct {
|
||||
Search string
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// AffiliateAdminEntry 专属用户列表条目
|
||||
type AffiliateAdminEntry struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
AffCode string `json:"aff_code"`
|
||||
AffCodeCustom bool `json:"aff_code_custom"`
|
||||
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
}
|
||||
|
||||
type AffiliateRecordFilter struct {
|
||||
Search string
|
||||
Page int
|
||||
PageSize int
|
||||
StartAt *time.Time
|
||||
EndAt *time.Time
|
||||
SortBy string
|
||||
SortDesc bool
|
||||
}
|
||||
|
||||
type AffiliateInviteRecord struct {
|
||||
InviterID int64 `json:"inviter_id"`
|
||||
InviterEmail string `json:"inviter_email"`
|
||||
InviterUsername string `json:"inviter_username"`
|
||||
InviteeID int64 `json:"invitee_id"`
|
||||
InviteeEmail string `json:"invitee_email"`
|
||||
InviteeUsername string `json:"invitee_username"`
|
||||
AffCode string `json:"aff_code"`
|
||||
TotalRebate float64 `json:"total_rebate"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateRebateRecord struct {
|
||||
OrderID int64 `json:"order_id"`
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
InviterID int64 `json:"inviter_id"`
|
||||
InviterEmail string `json:"inviter_email"`
|
||||
InviterUsername string `json:"inviter_username"`
|
||||
InviteeID int64 `json:"invitee_id"`
|
||||
InviteeEmail string `json:"invitee_email"`
|
||||
InviteeUsername string `json:"invitee_username"`
|
||||
OrderAmount float64 `json:"order_amount"`
|
||||
PayAmount float64 `json:"pay_amount"`
|
||||
RebateAmount float64 `json:"rebate_amount"`
|
||||
PaymentType string `json:"payment_type"`
|
||||
OrderStatus string `json:"order_status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateTransferRecord struct {
|
||||
LedgerID int64 `json:"ledger_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
UserEmail string `json:"user_email"`
|
||||
Username string `json:"username"`
|
||||
Amount float64 `json:"amount"`
|
||||
BalanceAfter *float64 `json:"balance_after,omitempty"`
|
||||
AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"`
|
||||
FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"`
|
||||
HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"`
|
||||
SnapshotAvailable bool `json:"snapshot_available"`
|
||||
CurrentBalance float64 `json:"-"`
|
||||
RemainingQuota float64 `json:"-"`
|
||||
FrozenQuota float64 `json:"-"`
|
||||
HistoryQuota float64 `json:"-"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateUserOverview struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
AffCode string `json:"aff_code"`
|
||||
RebateRatePercent float64 `json:"rebate_rate_percent"`
|
||||
RebateRateCustom bool `json:"-"`
|
||||
InvitedCount int `json:"invited_count"`
|
||||
RebatedInviteeCount int `json:"rebated_invitee_count"`
|
||||
AvailableQuota float64 `json:"available_quota"`
|
||||
HistoryQuota float64 `json:"history_quota"`
|
||||
}
|
||||
|
||||
type AffiliateService struct {
|
||||
repo AffiliateRepository
|
||||
settingRepo SettingRepository
|
||||
settingService *SettingService
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
func NewAffiliateService(repo AffiliateRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
|
||||
func NewAffiliateService(repo AffiliateRepository, settingService *SettingService, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
|
||||
return &AffiliateService{
|
||||
repo: repo,
|
||||
settingRepo: settingRepo,
|
||||
settingService: settingService,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled reports whether the affiliate (邀请返利) feature is turned on.
|
||||
func (s *AffiliateService) IsEnabled(ctx context.Context) bool {
|
||||
if s == nil || s.settingService == nil {
|
||||
return AffiliateEnabledDefault
|
||||
}
|
||||
return s.settingService.IsAffiliateEnabled(ctx)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
|
||||
if userID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
|
||||
@@ -111,6 +239,12 @@ func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64
|
||||
}
|
||||
|
||||
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
|
||||
// Lazy thaw: move any matured frozen quota to available before reading.
|
||||
if s != nil && s.repo != nil {
|
||||
// best-effort: thaw failure is non-fatal
|
||||
_, _ = s.repo.ThawFrozenQuota(ctx, userID)
|
||||
}
|
||||
|
||||
summary, err := s.EnsureUserAffiliate(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -120,13 +254,15 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64)
|
||||
return nil, err
|
||||
}
|
||||
return &AffiliateDetail{
|
||||
UserID: summary.UserID,
|
||||
AffCode: summary.AffCode,
|
||||
InviterID: summary.InviterID,
|
||||
AffCount: summary.AffCount,
|
||||
AffQuota: summary.AffQuota,
|
||||
AffHistoryQuota: summary.AffHistoryQuota,
|
||||
Invitees: invitees,
|
||||
UserID: summary.UserID,
|
||||
AffCode: summary.AffCode,
|
||||
InviterID: summary.InviterID,
|
||||
AffCount: summary.AffCount,
|
||||
AffQuota: summary.AffQuota,
|
||||
AffFrozenQuota: summary.AffFrozenQuota,
|
||||
AffHistoryQuota: summary.AffHistoryQuota,
|
||||
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
|
||||
Invitees: invitees,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -135,12 +271,16 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
|
||||
if code == "" {
|
||||
return nil
|
||||
}
|
||||
if !isValidAffiliateCodeFormat(code) {
|
||||
return ErrAffiliateCodeInvalid
|
||||
}
|
||||
if s == nil || s.repo == nil {
|
||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
// 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程)
|
||||
if !s.IsEnabled(ctx) {
|
||||
return nil
|
||||
}
|
||||
if !isValidAffiliateCodeFormat(code) {
|
||||
return ErrAffiliateCodeInvalid
|
||||
}
|
||||
|
||||
selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -172,12 +312,20 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
|
||||
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
|
||||
return 0, nil
|
||||
}
|
||||
// 总开关关闭时,新充值不再产生返利
|
||||
if !s.IsEnabled(ctx) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
|
||||
if err != nil {
|
||||
@@ -187,17 +335,48 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
rebateRatePercent := s.loadAffiliateRebateRatePercent(ctx)
|
||||
// 加载邀请人 profile,优先使用专属比例(覆盖全局)
|
||||
inviterSummary, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// 有效期检查:超过返利有效期后不再产生返利
|
||||
if s.settingService != nil {
|
||||
if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
|
||||
if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
|
||||
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
|
||||
if rebate <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if _, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID); err != nil {
|
||||
return 0, err
|
||||
// 单人上限检查:精确截断到剩余额度
|
||||
if s.settingService != nil {
|
||||
if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
|
||||
existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if existing >= perInviteeCap {
|
||||
return 0, nil
|
||||
}
|
||||
if remaining := perInviteeCap - existing; rebate > remaining {
|
||||
rebate = roundTo(remaining, 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate)
|
||||
var freezeHours int
|
||||
if s.settingService != nil {
|
||||
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
|
||||
}
|
||||
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -207,6 +386,28 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
|
||||
return rebate, nil
|
||||
}
|
||||
|
||||
// resolveRebateRatePercent returns the inviter's exclusive rate when set,
|
||||
// otherwise the global setting value (clamped to [Min, Max]).
|
||||
func (s *AffiliateService) resolveRebateRatePercent(ctx context.Context, inviter *AffiliateSummary) float64 {
|
||||
if inviter != nil && inviter.AffRebateRatePercent != nil {
|
||||
v := *inviter.AffRebateRatePercent
|
||||
if math.IsNaN(v) || math.IsInf(v, 0) {
|
||||
return s.globalRebateRatePercent(ctx)
|
||||
}
|
||||
return clampAffiliateRebateRate(v)
|
||||
}
|
||||
return s.globalRebateRatePercent(ctx)
|
||||
}
|
||||
|
||||
// globalRebateRatePercent reads the system-wide rebate rate via SettingService,
|
||||
// returning the documented default when SettingService is unavailable.
|
||||
func (s *AffiliateService) globalRebateRatePercent(ctx context.Context) float64 {
|
||||
if s == nil || s.settingService == nil {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
return s.settingService.GetAffiliateRebateRatePercent(ctx)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
@@ -236,32 +437,6 @@ func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([
|
||||
return invitees, nil
|
||||
}
|
||||
|
||||
func (s *AffiliateService) loadAffiliateRebateRatePercent(ctx context.Context) float64 {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
|
||||
if err != nil {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
|
||||
rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
|
||||
if err != nil {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
if math.IsNaN(rate) || math.IsInf(rate, 0) {
|
||||
return AffiliateRebateRateDefault
|
||||
}
|
||||
if rate < AffiliateRebateRateMin {
|
||||
return AffiliateRebateRateMin
|
||||
}
|
||||
if rate > AffiliateRebateRateMax {
|
||||
return AffiliateRebateRateMax
|
||||
}
|
||||
return rate
|
||||
}
|
||||
|
||||
func roundTo(v float64, scale int) float64 {
|
||||
factor := math.Pow10(scale)
|
||||
return math.Round(v*factor) / factor
|
||||
@@ -312,3 +487,138 @@ func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Admin: 专属配置管理
|
||||
// =========================
|
||||
|
||||
// validateExclusiveRate ensures a per-user override is finite and within
|
||||
// [Min, Max]. nil is always valid (means "clear / fall back to global").
|
||||
func validateExclusiveRate(ratePercent *float64) error {
|
||||
if ratePercent == nil {
|
||||
return nil
|
||||
}
|
||||
v := *ratePercent
|
||||
if math.IsNaN(v) || math.IsInf(v, 0) {
|
||||
return infraerrors.BadRequest("INVALID_RATE", "invalid rebate rate")
|
||||
}
|
||||
if v < AffiliateRebateRateMin || v > AffiliateRebateRateMax {
|
||||
return infraerrors.BadRequest("INVALID_RATE", "rebate rate out of range")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。
|
||||
func (s *AffiliateService) AdminUpdateUserAffCode(ctx context.Context, userID int64, rawCode string) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
code := strings.ToUpper(strings.TrimSpace(rawCode))
|
||||
if !isValidAffiliateCodeFormat(code) {
|
||||
return ErrAffiliateCodeInvalid
|
||||
}
|
||||
return s.repo.UpdateUserAffCode(ctx, userID, code)
|
||||
}
|
||||
|
||||
// AdminResetUserAffCode 重置用户邀请码为系统随机码。
|
||||
func (s *AffiliateService) AdminResetUserAffCode(ctx context.Context, userID int64) (string, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return "", infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ResetUserAffCode(ctx, userID)
|
||||
}
|
||||
|
||||
// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。
|
||||
func (s *AffiliateService) AdminSetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
if err := validateExclusiveRate(ratePercent); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repo.SetUserRebateRate(ctx, userID, ratePercent)
|
||||
}
|
||||
|
||||
// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。
|
||||
func (s *AffiliateService) AdminBatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
if err := validateExclusiveRate(ratePercent); err != nil {
|
||||
return err
|
||||
}
|
||||
cleaned := make([]int64, 0, len(userIDs))
|
||||
for _, uid := range userIDs {
|
||||
if uid > 0 {
|
||||
cleaned = append(cleaned, uid)
|
||||
}
|
||||
}
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.repo.BatchSetUserRebateRate(ctx, cleaned, ratePercent)
|
||||
}
|
||||
|
||||
// AdminListCustomUsers 列出有专属配置的用户。
|
||||
func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListUsersWithCustomSettings(ctx, filter)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) {
|
||||
if userID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
|
||||
}
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
overview, err := s.repo.GetAffiliateUserOverview(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if overview != nil {
|
||||
if !overview.RebateRateCustom {
|
||||
overview.RebateRatePercent = s.globalRebateRatePercent(ctx)
|
||||
}
|
||||
overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent)
|
||||
}
|
||||
return overview, nil
|
||||
}
|
||||
|
||||
func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter {
|
||||
if filter.Page <= 0 {
|
||||
filter.Page = 1
|
||||
}
|
||||
if filter.PageSize <= 0 {
|
||||
filter.PageSize = 20
|
||||
}
|
||||
if filter.PageSize > 100 {
|
||||
filter.PageSize = 100
|
||||
}
|
||||
filter.Search = strings.TrimSpace(filter.Search)
|
||||
filter.SortBy = strings.TrimSpace(filter.SortBy)
|
||||
return filter
|
||||
}
|
||||
|
||||
@@ -4,51 +4,82 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type affiliateSettingRepoStub struct {
|
||||
value string
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *affiliateSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, s.err }
|
||||
func (s *affiliateSettingRepoStub) GetValue(context.Context, string) (string, error) {
|
||||
if s.err != nil {
|
||||
return "", s.err
|
||||
}
|
||||
return s.value, nil
|
||||
}
|
||||
func (s *affiliateSettingRepoStub) Set(context.Context, string, string) error { return s.err }
|
||||
func (s *affiliateSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
func (s *affiliateSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
return s.err
|
||||
}
|
||||
func (s *affiliateSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
func (s *affiliateSettingRepoStub) Delete(context.Context, string) error { return s.err }
|
||||
|
||||
func TestAffiliateRebateRatePercentSemantics(t *testing.T) {
|
||||
// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
|
||||
// AffRebateRatePercent overrides the global rate, that NULL falls back to the
|
||||
// global rate, and that out-of-range exclusive rates are clamped silently.
|
||||
//
|
||||
// SettingService is left nil here so globalRebateRatePercent returns the
|
||||
// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
|
||||
// fallback path without spinning up a settings stub.
|
||||
func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) {
|
||||
t.Parallel()
|
||||
svc := &AffiliateService{}
|
||||
|
||||
svc := &AffiliateService{settingRepo: &affiliateSettingRepoStub{value: "1"}}
|
||||
rate := svc.loadAffiliateRebateRatePercent(context.Background())
|
||||
require.Equal(t, 1.0, rate)
|
||||
// nil exclusive rate → falls back to global default (20%)
|
||||
require.InDelta(t, AffiliateRebateRateDefault,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9)
|
||||
|
||||
svc.settingRepo = &affiliateSettingRepoStub{value: "0.2"}
|
||||
rate = svc.loadAffiliateRebateRatePercent(context.Background())
|
||||
require.Equal(t, 0.2, rate)
|
||||
// exclusive rate set → overrides global
|
||||
rate := 50.0
|
||||
require.InDelta(t, 50.0,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9)
|
||||
|
||||
// exclusive rate 0 → returns 0 (no rebate, intentional)
|
||||
zero := 0.0
|
||||
require.InDelta(t, 0.0,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9)
|
||||
|
||||
// exclusive rate above max → clamped to Max
|
||||
tooHigh := 250.0
|
||||
require.InDelta(t, AffiliateRebateRateMax,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9)
|
||||
|
||||
// exclusive rate below min → clamped to Min
|
||||
tooLow := -5.0
|
||||
require.InDelta(t, AffiliateRebateRateMin,
|
||||
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9)
|
||||
}
|
||||
|
||||
// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
|
||||
// safely handles a nil settingService dependency by returning the default
|
||||
// (off). This protects callers from nil-pointer crashes in misconfigured
|
||||
// environments.
|
||||
func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
svc := &AffiliateService{}
|
||||
require.False(t, svc.IsEnabled(context.Background()))
|
||||
require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background()))
|
||||
}
|
||||
|
||||
// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
|
||||
// admin-facing rate setters: nil is always valid (clear), in-range values
|
||||
// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
|
||||
func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, validateExclusiveRate(nil))
|
||||
|
||||
for _, v := range []float64{0, 0.01, 50, 99.99, 100} {
|
||||
v := v
|
||||
require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v)
|
||||
}
|
||||
|
||||
for _, v := range []float64{-0.01, 100.01, -100, 200} {
|
||||
v := v
|
||||
require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v)
|
||||
}
|
||||
|
||||
nan := math.NaN()
|
||||
require.Error(t, validateExclusiveRate(&nan))
|
||||
posInf := math.Inf(1)
|
||||
require.Error(t, validateExclusiveRate(&posInf))
|
||||
negInf := math.Inf(-1)
|
||||
require.Error(t, validateExclusiveRate(&negInf))
|
||||
}
|
||||
|
||||
func TestMaskEmail(t *testing.T) {
|
||||
@@ -61,24 +92,33 @@ func TestMaskEmail(t *testing.T) {
|
||||
func TestIsValidAffiliateCodeFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 邀请码格式校验同时服务于:
|
||||
// 1) 系统自动生成的 12 位随机码(A-Z 去 I/O,2-9 去 0/1)
|
||||
// 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1")
|
||||
// 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper)。
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"valid canonical", "ABCDEFGHJKLM", true},
|
||||
{"valid canonical 12-char", "ABCDEFGHJKLM", true},
|
||||
{"valid all digits 2-9", "234567892345", true},
|
||||
{"valid mixed", "A2B3C4D5E6F7", true},
|
||||
{"too short", "ABCDEFGHJKL", false},
|
||||
{"too long", "ABCDEFGHJKLMN", false},
|
||||
{"contains excluded letter I", "IBCDEFGHJKLM", false},
|
||||
{"contains excluded letter O", "OBCDEFGHJKLM", false},
|
||||
{"contains excluded digit 0", "0BCDEFGHJKLM", false},
|
||||
{"contains excluded digit 1", "1BCDEFGHJKLM", false},
|
||||
{"valid admin custom short", "VIP1", true},
|
||||
{"valid admin custom with hyphen", "NEW-USER", true},
|
||||
{"valid admin custom with underscore", "VIP_2026", true},
|
||||
{"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true},
|
||||
// Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
|
||||
{"letter I now allowed", "IBCDEFGHJKLM", true},
|
||||
{"letter O now allowed", "OBCDEFGHJKLM", true},
|
||||
{"digit 0 now allowed", "0BCDEFGHJKLM", true},
|
||||
{"digit 1 now allowed", "1BCDEFGHJKLM", true},
|
||||
{"too short (3 chars)", "ABC", false},
|
||||
{"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false},
|
||||
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
|
||||
{"empty", "", false},
|
||||
{"12-byte utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // 6×2 bytes = 12 bytes, bytes out of charset
|
||||
{"ascii punctuation", "ABCDEFGHJK.M", false},
|
||||
{"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset
|
||||
{"ascii punctuation .", "ABCDEFGHJK.M", false},
|
||||
{"whitespace", "ABCDEFGHJK M", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
|
||||
@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
user *User,
|
||||
invitationCode string,
|
||||
signupSource string,
|
||||
affiliateCode string,
|
||||
) error {
|
||||
if s == nil || user == nil || user.ID <= 0 {
|
||||
return ErrServiceUnavailable
|
||||
@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -563,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
|
||||
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
@@ -666,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
}
|
||||
} else {
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
@@ -683,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
@@ -777,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
|
||||
}
|
||||
}
|
||||
|
||||
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
|
||||
// for an OAuth-registered user. Failures are logged but never block registration.
|
||||
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
|
||||
if s.affiliateService == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
|
||||
}
|
||||
if code := strings.TrimSpace(affiliateCode); code != "" {
|
||||
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return
|
||||
|
||||
@@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.NotNil(t, user)
|
||||
@@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.Equal(t, existing.ID, user.ID)
|
||||
|
||||
@@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key.
|
||||
func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// API Key 限速缓存方法
|
||||
// ============================================
|
||||
|
||||
@@ -17,7 +17,7 @@ const (
|
||||
// ClaudeTokenCache token cache interface.
|
||||
type ClaudeTokenCache = GeminiTokenCache
|
||||
|
||||
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
|
||||
// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts.
|
||||
type ClaudeTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
@@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
|
||||
return "", errors.New("not an anthropic oauth or service account")
|
||||
}
|
||||
if account.Type == AccountTypeServiceAccount {
|
||||
return p.getServiceAccountAccessToken(ctx, account)
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
@@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
return "", errors.New("not an anthropic oauth or service account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
@@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
@@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
@@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
|
||||
@@ -20,9 +20,15 @@ const (
|
||||
|
||||
// Affiliate rebate settings
|
||||
const (
|
||||
AffiliateRebateRateDefault = 20.0
|
||||
AffiliateRebateRateMin = 0.0
|
||||
AffiliateRebateRateMax = 100.0
|
||||
AffiliateRebateRateDefault = 20.0
|
||||
AffiliateRebateRateMin = 0.0
|
||||
AffiliateRebateRateMax = 100.0
|
||||
AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
|
||||
AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容)
|
||||
AffiliateRebateFreezeHoursMax = 720 // 最大 30 天
|
||||
AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
|
||||
AffiliateRebateDurationDaysMax = 3650 // ~10 年
|
||||
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
@@ -35,19 +41,21 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account 类型账号(用于 Vertex AI)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
const (
|
||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||
RedeemTypeAffiliateBalance = "affiliate_balance"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
@@ -94,7 +102,11 @@ const (
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
|
||||
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
|
||||
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
|
||||
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
|
||||
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
@@ -296,6 +308,12 @@ const (
|
||||
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
|
||||
SettingKeyBetaPolicySettings = "beta_policy_settings"
|
||||
|
||||
// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
|
||||
// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
|
||||
// targets OpenAI's body-level service_tier field instead of Claude's
|
||||
// anthropic-beta header.
|
||||
SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
|
||||
|
||||
// =========================
|
||||
// Claude Code Version Check
|
||||
// =========================
|
||||
@@ -319,6 +337,8 @@ const (
|
||||
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
|
||||
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
|
||||
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
||||
// SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection"
|
||||
|
||||
// Balance Low Notification
|
||||
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer inbound-token")
|
||||
c.Request.Header.Set("X-Api-Key", "inbound-api-key")
|
||||
c.Request.Header.Set("Anthropic-Version", "2023-06-01")
|
||||
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
|
||||
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeServiceAccount,
|
||||
Credentials: map[string]any{
|
||||
"project_id": "vertex-proj",
|
||||
"location": "us-east5",
|
||||
},
|
||||
}
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
svc := &GatewayService{}
|
||||
req, err := svc.buildUpstreamRequest(
|
||||
context.Background(),
|
||||
c,
|
||||
account,
|
||||
body,
|
||||
"vertex-token",
|
||||
"service_account",
|
||||
"claude-sonnet-4-5@20250929",
|
||||
false,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String())
|
||||
require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization"))
|
||||
require.Empty(t, getHeaderRaw(req.Header, "x-api-key"))
|
||||
require.Empty(t, getHeaderRaw(req.Header, "anthropic-version"))
|
||||
require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, "anthropic-beta"))
|
||||
|
||||
got := readRequestBodyForTest(t, req)
|
||||
require.Equal(t, "", gjson.GetBytes(got, "model").String())
|
||||
require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
|
||||
require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String())
|
||||
}
|
||||
|
||||
func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
|
||||
t.Helper()
|
||||
require.NotNil(t, req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
return body
|
||||
}
|
||||
@@ -1,13 +1,91 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type gatewayTTLSettingRepo struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Get(context.Context, string) (*Setting, error) {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
if r == nil {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
v, ok := r.data[key]
|
||||
if !ok {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Set(_ context.Context, key, value string) error {
|
||||
if r == nil {
|
||||
return errors.New("setting repo is nil")
|
||||
}
|
||||
if r.data == nil {
|
||||
r.data = map[string]string{}
|
||||
}
|
||||
r.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
if r == nil {
|
||||
return result, nil
|
||||
}
|
||||
for _, key := range keys {
|
||||
if v, ok := r.data[key]; ok {
|
||||
result[key] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r == nil {
|
||||
return errors.New("setting repo is nil")
|
||||
}
|
||||
if r.data == nil {
|
||||
r.data = map[string]string{}
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.data[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetAll(context.Context) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
if r == nil {
|
||||
return result, nil
|
||||
}
|
||||
for key, value := range r.data {
|
||||
result[key] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Delete(_ context.Context, key string) error {
|
||||
if r != nil {
|
||||
delete(r.data, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -71,3 +149,60 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
||||
}
|
||||
|
||||
func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) {
|
||||
body := []byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
|
||||
|
||||
result := injectAnthropicCacheControlTTL1h(body)
|
||||
resultStr := string(result)
|
||||
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"cache_control"`, `"system"`, `"messages"`, `"tools"`, `"omega"`)
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "cache_control.ttl").String())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||
require.False(t, gjson.GetBytes(result, "system.1.cache_control").Exists())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||
require.Equal(t, "5m", gjson.GetBytes(result, "messages.0.content.1.cache_control.ttl").String())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "tools.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestGatewayCacheTTLGlobalSetting_TargetResolution(t *testing.T) {
|
||||
repo := &gatewayTTLSettingRepo{data: map[string]string{
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
|
||||
}}
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}
|
||||
|
||||
target, ok := svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, cacheTTLTarget5m, target)
|
||||
|
||||
account.Extra = map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "1h",
|
||||
}
|
||||
target, ok = svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, cacheTTLTarget1h, target)
|
||||
}
|
||||
|
||||
func TestGatewayCacheTTLGlobalSetting_RequestInjectionScope(t *testing.T) {
|
||||
repo := &gatewayTTLSettingRepo{data: map[string]string{
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
|
||||
}}
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
|
||||
require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
|
||||
require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeSetupToken}))
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey}))
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}))
|
||||
|
||||
repo.data[SettingKeyEnableAnthropicCacheTTL1hInjection] = "false"
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
|
||||
}
|
||||
|
||||
@@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
}
|
||||
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(originalModel)
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
|
||||
@@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
}
|
||||
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(originalModel)
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
|
||||
@@ -9,6 +9,11 @@ import (
|
||||
)
|
||||
|
||||
func TestIsClaudeCodeClient(t *testing.T) {
|
||||
// 合法的 legacy 格式 metadata.user_id(64位 hex + account uuid + session uuid)
|
||||
legacyUserID := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
|
||||
// 合法的 JSON 格式 metadata.user_id(2.1.78+ 版本)
|
||||
jsonUserID := `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"123e4567-e89b-12d3-a456-426614174000"}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
@@ -16,15 +21,21 @@ func TestIsClaudeCodeClient(t *testing.T) {
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Claude Code client",
|
||||
name: "Claude Code client with legacy user_id",
|
||||
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
|
||||
metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
||||
metadataUserID: legacyUserID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Claude Code without version suffix",
|
||||
userAgent: "claude-cli/2.0.0",
|
||||
metadataUserID: "session_abc",
|
||||
name: "Claude Code client with JSON user_id",
|
||||
userAgent: "claude-cli/2.1.92 (external, cli)",
|
||||
metadataUserID: jsonUserID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Claude Code case insensitive UA",
|
||||
userAgent: "Claude-CLI/2.0.0",
|
||||
metadataUserID: legacyUserID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
@@ -34,21 +45,33 @@ func TestIsClaudeCodeClient(t *testing.T) {
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Different user agent",
|
||||
name: "Claude CLI UA with invalid user_id format",
|
||||
userAgent: "claude-cli/2.0.0",
|
||||
metadataUserID: "fake-user-id-12345",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Different user agent with valid user_id",
|
||||
userAgent: "curl/7.68.0",
|
||||
metadataUserID: "user123",
|
||||
metadataUserID: legacyUserID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Empty user agent",
|
||||
userAgent: "",
|
||||
metadataUserID: "user123",
|
||||
metadataUserID: legacyUserID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Similar but not Claude CLI",
|
||||
userAgent: "claude-api/1.0.0",
|
||||
metadataUserID: "user123",
|
||||
metadataUserID: legacyUserID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Opencode spoofing UA with arbitrary user_id",
|
||||
userAgent: "claude-cli/2.1.92",
|
||||
metadataUserID: "session_abc",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -60,6 +62,11 @@ const (
|
||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTTLTarget5m = "5m"
|
||||
cacheTTLTarget1h = "1h"
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
@@ -119,7 +126,7 @@ func openAIStreamEventIsTerminal(data string) bool {
|
||||
return true
|
||||
}
|
||||
switch gjson.Get(trimmed, "type").String() {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -329,7 +336,7 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
@@ -652,15 +659,31 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
|
||||
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
|
||||
if parsed.MetadataUserID != "" {
|
||||
if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
|
||||
uid := ParseMetadataUserID(parsed.MetadataUserID)
|
||||
if uid != nil && uid.SessionID != "" {
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "metadata_user_id",
|
||||
"session_id", uid.SessionID,
|
||||
"device_id", uid.DeviceID,
|
||||
"is_new_format", uid.IsNewFormat,
|
||||
)
|
||||
return uid.SessionID
|
||||
}
|
||||
slog.Info("sticky.hash_metadata_parse_failed",
|
||||
"metadata_user_id", parsed.MetadataUserID,
|
||||
"parsed_nil", uid == nil,
|
||||
)
|
||||
}
|
||||
|
||||
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
|
||||
cacheableContent := s.extractCacheableContent(parsed)
|
||||
if cacheableContent != "" {
|
||||
return s.hashContent(cacheableContent)
|
||||
hash := s.hashContent(cacheableContent)
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "cacheable_content",
|
||||
"hash", hash,
|
||||
)
|
||||
return hash
|
||||
}
|
||||
|
||||
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
|
||||
@@ -700,7 +723,13 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
}
|
||||
}
|
||||
if combined.Len() > 0 {
|
||||
return s.hashContent(combined.String())
|
||||
hash := s.hashContent(combined.String())
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "message_content_fallback",
|
||||
"hash", hash,
|
||||
"content_len", combined.Len(),
|
||||
)
|
||||
return hash
|
||||
}
|
||||
|
||||
return ""
|
||||
@@ -1404,14 +1433,29 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
var stickyAccountID int64
|
||||
var stickySource string
|
||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
|
||||
stickyAccountID = prefetch
|
||||
stickySource = "prefetch"
|
||||
} else if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
stickySource = "cache"
|
||||
}
|
||||
}
|
||||
|
||||
// [DEBUG-STICKY] 调度器入口日志
|
||||
slog.Info("sticky.scheduler_entry",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"session_hash", shortSessionHash(sessionHash),
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"sticky_source", stickySource,
|
||||
"model", requestedModel,
|
||||
"load_batch", cfg.LoadBatchEnabled,
|
||||
"has_concurrency_svc", s.concurrencyService != nil,
|
||||
"excluded_count", len(excludedIDs),
|
||||
)
|
||||
|
||||
if s.debugModelRoutingEnabled() && requestedModel != "" {
|
||||
groupPlatform := ""
|
||||
if group != nil {
|
||||
@@ -1587,6 +1631,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if len(routingCandidates) > 0 {
|
||||
// 1.5. 在路由账号范围内检查粘性会话
|
||||
if sessionHash != "" && stickyAccountID > 0 {
|
||||
slog.Debug("sticky.layer1_5_checking",
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"in_routing_list", containsInt64(routingAccountIDs, stickyAccountID),
|
||||
"is_excluded", isExcluded(stickyAccountID),
|
||||
"in_account_map", func() bool { _, ok := accountByID[stickyAccountID]; return ok }(),
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
@@ -1610,6 +1661,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_hit",
|
||||
"account_id", stickyAccountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "slot_acquired",
|
||||
)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
@@ -1760,27 +1816,65 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
clearSticky := shouldClearStickySession(account, requestedModel)
|
||||
if clearSticky {
|
||||
slog.Debug("sticky.layer1_5_no_routing_clear",
|
||||
"account_id", accountID,
|
||||
"reason", "should_clear_sticky_session",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(account) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
|
||||
// 注意:不再检查 isAccountInGroup,因为 accountByID 已经从按分组过滤的
|
||||
// accounts 列表构建,账号一定在分组内。而 scheduler snapshot 缓存
|
||||
// 反序列化后 AccountGroups 字段为空,导致 isAccountInGroup 永远返回 false。
|
||||
platformOK := s.isAccountAllowedForPlatform(account, platform, useMixed)
|
||||
modelSupported := requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)
|
||||
modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, requestedModel)
|
||||
quotaOK := s.isAccountSchedulableForQuota(account)
|
||||
windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true)
|
||||
rpmOK := s.isAccountSchedulableForRPM(ctx, account, true)
|
||||
schedulable := s.isAccountSchedulableForSelection(account)
|
||||
|
||||
slog.Debug("sticky.layer1_5_no_routing_checks",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"clear_sticky", clearSticky,
|
||||
"schedulable", schedulable,
|
||||
"platform_ok", platformOK,
|
||||
"model_supported", modelSupported,
|
||||
"model_schedulable", modelSchedulable,
|
||||
"quota_ok", quotaOK,
|
||||
"window_cost_ok", windowCostOK,
|
||||
"rpm_ok", rpmOK,
|
||||
)
|
||||
|
||||
if !clearSticky && platformOK && modelSupported && modelSchedulable && quotaOK && windowCostOK && rpmOK && schedulable {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "session_limit",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_hit",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "slot_acquired",
|
||||
)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
}
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_slot_busy",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
@@ -1789,6 +1883,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_hit",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "wait_plan",
|
||||
)
|
||||
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
@@ -1797,12 +1896,42 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if !clearSticky {
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "gate_check_failed",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "account_not_in_map",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if len(routingAccountIDs) == 0 && sessionHash != "" {
|
||||
slog.Debug("sticky.layer1_5_no_routing_skip",
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"is_excluded", func() bool { return stickyAccountID > 0 && isExcluded(stickyAccountID) }(),
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"reason", func() string {
|
||||
if stickyAccountID == 0 {
|
||||
return "no_sticky_binding"
|
||||
}
|
||||
return "sticky_account_excluded"
|
||||
}(),
|
||||
)
|
||||
}
|
||||
|
||||
// ============ Layer 2: 负载感知选择 ============
|
||||
slog.Debug("sticky.layer2_fallback",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"reason", "sticky_not_used_falling_back_to_load_balance",
|
||||
"total_accounts", len(accounts),
|
||||
)
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -3597,7 +3726,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
if account.Type == AccountTypeServiceAccount {
|
||||
requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel))
|
||||
} else {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
}
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
@@ -3617,6 +3750,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
return apiKey, "apikey", nil
|
||||
case AccountTypeBedrock:
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
|
||||
case AccountTypeServiceAccount:
|
||||
if account.Platform != PlatformAnthropic {
|
||||
return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform)
|
||||
}
|
||||
if s.claudeTokenProvider == nil {
|
||||
return "", "", errors.New("claude token provider not configured")
|
||||
}
|
||||
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "service_account", nil
|
||||
default:
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
@@ -3709,13 +3854,19 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
}
|
||||
}
|
||||
|
||||
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
|
||||
// 简化判断:User-Agent 匹配 + metadata.user_id 存在
|
||||
// isClaudeCodeClient 判断请求是否来自真正的 Claude Code 客户端。
|
||||
// 判定条件:
|
||||
// 1. User-Agent 匹配 claude-cli/X.Y.Z(大小写不敏感)
|
||||
// 2. metadata.user_id 符合 Claude Code 格式(legacy 或 JSON 格式)
|
||||
//
|
||||
// 只检查 metadata.user_id 非空不够严格:第三方工具(opencode 等)可能伪造 UA
|
||||
// 并附带任意 metadata.user_id 字符串,从而绕过 mimicry。必须通过 ParseMetadataUserID
|
||||
// 验证格式才能确认是真正的 Claude Code 客户端。
|
||||
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
if metadataUserID == "" {
|
||||
if !claudeCliUserAgentRe.MatchString(userAgent) {
|
||||
return false
|
||||
}
|
||||
return claudeCliUserAgentRe.MatchString(userAgent)
|
||||
return ParseMetadataUserID(metadataUserID) != nil
|
||||
}
|
||||
|
||||
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
|
||||
@@ -4080,6 +4231,87 @@ func enforceCacheControlLimit(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// injectAnthropicCacheControlTTL1h 将已有 ephemeral cache_control 块的 ttl 强制写为 1h。
|
||||
// 仅修改已经存在的 cache_control,不新增缓存断点。
|
||||
func injectAnthropicCacheControlTTL1h(body []byte) []byte {
|
||||
return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h)
|
||||
}
|
||||
|
||||
func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte {
|
||||
if len(body) == 0 || ttl == "" {
|
||||
return body
|
||||
}
|
||||
out := body
|
||||
var paths []string
|
||||
addPath := func(path string, value gjson.Result) {
|
||||
cc := value.Get("cache_control")
|
||||
if !cc.Exists() || cc.Get("type").String() != "ephemeral" {
|
||||
return
|
||||
}
|
||||
if cc.Get("ttl").String() == ttl {
|
||||
return
|
||||
}
|
||||
paths = append(paths, path+".cache_control.ttl")
|
||||
}
|
||||
|
||||
if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl {
|
||||
paths = append(paths, "cache_control.ttl")
|
||||
}
|
||||
|
||||
system := gjson.GetBytes(body, "system")
|
||||
if system.IsArray() {
|
||||
idx := -1
|
||||
system.ForEach(func(_, block gjson.Result) bool {
|
||||
idx++
|
||||
addPath(fmt.Sprintf("system.%d", idx), block)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if messages.IsArray() {
|
||||
msgIdx := -1
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
msgIdx++
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
return true
|
||||
}
|
||||
contentIdx := -1
|
||||
content.ForEach(func(_, block gjson.Result) bool {
|
||||
contentIdx++
|
||||
addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.IsArray() {
|
||||
idx := -1
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
idx++
|
||||
addPath(fmt.Sprintf("tools.%d", idx), tool)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
if next, err := sjson.SetBytes(out, path, ttl); err == nil {
|
||||
out = next
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool {
|
||||
if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil {
|
||||
return false
|
||||
}
|
||||
return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx)
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -4144,12 +4376,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
})
|
||||
}
|
||||
|
||||
// OAuth 账号无条件走完整 mimicry,与 Parrot 对齐。
|
||||
// 不再检查 isClaudeCodeRequest —— 即使客户端自称 Claude Code(opencode 等
|
||||
// 第三方工具会伪装 UA / X-App / system prompt),它的伪装往往不完整(缺 billing
|
||||
// block / 工具名混淆 / cache 策略等),被 Anthropic 判为 third-party。
|
||||
// 无条件覆盖不会对真正的 Claude Code 造成问题,因为我们的伪装更完整。
|
||||
shouldMimicClaudeCode := account.IsOAuth()
|
||||
// Claude Code 客户端判定:UA 匹配 claude-cli/* 且携带 metadata.user_id。
|
||||
// 真正的 Claude Code 客户端自带完整的 system prompt、cache_control 断点和 header,
|
||||
// 不需要代理做任何 body 级别的 mimicry;强行替换反而会破坏客户端的缓存策略
|
||||
// (长 system prompt 被替换为 ~45 tokens 的短 prompt,低于 Anthropic 1024 token
|
||||
// 最低缓存门槛,导致系统级缓存失效)。
|
||||
//
|
||||
// 对于非 Claude Code 的第三方客户端(opencode 等),仍然走完整 mimicry。
|
||||
isClaudeCode := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
// 与 Parrot 对齐:OAuth 账号无条件重写 system(即使客户端已发了 Claude Code
|
||||
@@ -4210,6 +4445,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
mappingSource = "account"
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
if candidate, matched := account.ResolveMappedModel(reqModel); matched {
|
||||
mappedModel = candidate
|
||||
mappingSource = "account"
|
||||
} else {
|
||||
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel))
|
||||
if normalized != reqModel {
|
||||
mappedModel = normalized
|
||||
mappingSource = "vertex"
|
||||
}
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(reqModel)
|
||||
if normalized != reqModel {
|
||||
@@ -4224,6 +4471,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
|
||||
if s.shouldInjectAnthropicCacheTTL1h(ctx, account) {
|
||||
body = injectAnthropicCacheControlTTL1h(body)
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -5679,6 +5930,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
|
||||
if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
|
||||
}
|
||||
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
@@ -5865,6 +6120,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
token string,
|
||||
modelID string,
|
||||
reqStream bool,
|
||||
) (*http.Request, error) {
|
||||
vertexBody, err := buildVertexAnthropicRequestBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setOpsUpstreamRequestBody(c, vertexBody)
|
||||
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c != nil && c.Request != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(strings.TrimSpace(key))
|
||||
if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" {
|
||||
continue
|
||||
}
|
||||
wireKey := resolveWireCasing(key)
|
||||
for _, v := range values {
|
||||
addHeaderRaw(req.Header, wireKey, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Del("authorization")
|
||||
req.Header.Del("x-api-key")
|
||||
req.Header.Del("x-goog-api-key")
|
||||
req.Header.Del("cookie")
|
||||
req.Header.Del("anthropic-version")
|
||||
setHeaderRaw(req.Header, "authorization", "Bearer "+token)
|
||||
setHeaderRaw(req.Header, "content-type", "application/json")
|
||||
|
||||
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
|
||||
"url": req.URL.String(),
|
||||
"token_type": "service_account",
|
||||
"model": modelID,
|
||||
"stream": strconv.FormatBool(reqStream),
|
||||
})
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// getBetaHeader 处理anthropic-beta header
|
||||
// 对于OAuth账号,需要确保包含oauth-2025-04-20
|
||||
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
|
||||
@@ -6425,6 +6734,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// sanitizeStreamError 返回不含网络地址的客户端可见错误描述。
|
||||
// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游
|
||||
// 服务器地址(例如 "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection
|
||||
// reset by peer")。该函数只保留可识别的错误类别,原始 err 仍在调用点写入日志。
|
||||
func sanitizeStreamError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
switch {
|
||||
case errors.Is(err, io.ErrUnexpectedEOF):
|
||||
return "unexpected EOF"
|
||||
case errors.Is(err, io.EOF):
|
||||
return "EOF"
|
||||
case errors.Is(err, context.Canceled):
|
||||
return "canceled"
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return "deadline exceeded"
|
||||
case errors.Is(err, syscall.ECONNRESET):
|
||||
return "connection reset by peer"
|
||||
case errors.Is(err, syscall.ECONNABORTED):
|
||||
return "connection aborted"
|
||||
case errors.Is(err, syscall.ETIMEDOUT):
|
||||
return "connection timed out"
|
||||
case errors.Is(err, syscall.EPIPE):
|
||||
return "broken pipe"
|
||||
case errors.Is(err, syscall.ECONNREFUSED):
|
||||
return "connection refused"
|
||||
}
|
||||
var netErr *net.OpError
|
||||
if errors.As(err, &netErr) {
|
||||
if netErr.Timeout() {
|
||||
if netErr.Op != "" {
|
||||
return netErr.Op + " timeout"
|
||||
}
|
||||
return "i/o timeout"
|
||||
}
|
||||
if netErr.Op != "" {
|
||||
return netErr.Op + " network error"
|
||||
}
|
||||
}
|
||||
return "upstream connection error"
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
|
||||
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
func ExtractUpstreamErrorMessage(body []byte) string {
|
||||
@@ -6862,14 +7214,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)。
|
||||
// 事件格式遵循 Anthropic SSE 标准:{"type":"error","error":{"type":<reason>,"message":<message>}}
|
||||
// 这样 Anthropic SDK / Claude Code 等客户端能按标准 error 类型解析,UI 能显示具体错误文案,
|
||||
// 服务端 ExtractUpstreamErrorMessage 也能从透传的 body 中提取 message。
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
sendErrorEvent := func(reason, message string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
if message == "" {
|
||||
message = reason
|
||||
}
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]string{
|
||||
"type": reason,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// json.Marshal 不可能在已知 string-only 输入上失败,保守 fallback
|
||||
body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message))
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
@@ -6946,9 +7315,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
if eventType == "message_start" {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
@@ -7029,10 +7398,32 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize))
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
// 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY):
|
||||
// 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。
|
||||
// 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。
|
||||
// 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 剥离地址,
|
||||
// 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err
|
||||
// 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。
|
||||
disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err)
|
||||
if !c.Writer.Written() {
|
||||
logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err)
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]string{
|
||||
"type": "upstream_disconnected",
|
||||
"message": disconnectMsg,
|
||||
},
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseBody: body,
|
||||
RetryableOnSameAccount: true,
|
||||
}
|
||||
}
|
||||
sendErrorEvent("stream_read_error", disconnectMsg)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
line := ev.line
|
||||
@@ -7091,7 +7482,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||
}
|
||||
sendErrorEvent("stream_timeout")
|
||||
sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval))
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
@@ -7333,6 +7724,19 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) {
|
||||
if account == nil {
|
||||
return "", false
|
||||
}
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
return account.GetCacheTTLOverrideTarget(), true
|
||||
}
|
||||
if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) {
|
||||
return cacheTTLTarget5m, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
@@ -7369,9 +7773,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
||||
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
|
||||
@@ -7770,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
|
||||
}
|
||||
|
||||
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
if !stream {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithoutCancel(ctx), func() {}
|
||||
}
|
||||
|
||||
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
@@ -7939,10 +8350,11 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
applyCacheTTLOverride(&result.Usage, overrideTarget)
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
|
||||
@@ -8387,7 +8799,8 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
// Pre-filter: strip empty text blocks to prevent upstream 400.
|
||||
body = StripEmptyTextBlocks(body)
|
||||
|
||||
shouldMimicClaudeCode := account.IsOAuth()
|
||||
isClaudeCodeCT := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCodeCT
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type upstreamContextTestKey string
|
||||
|
||||
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
|
||||
require.Equal(t, 3, result.usage.InputTokens)
|
||||
require.Equal(t, 7, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
|
||||
parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value"))
|
||||
upstreamCtx, release := detachUpstreamContext(parent)
|
||||
defer release()
|
||||
|
||||
cancel()
|
||||
|
||||
require.NoError(t, upstreamCtx.Err())
|
||||
require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key")))
|
||||
}
|
||||
|
||||
@@ -4,9 +4,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -218,3 +221,175 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件")
|
||||
}
|
||||
|
||||
// 上游中途读错误(如 HTTP/2 GOAWAY 触发的 unexpected EOF)发生在向客户端写入任何字节前:
|
||||
// 网关应返回 *UpstreamFailoverError 触发账号 failover/重试,而不是把错误事件直接发给客户端。
|
||||
func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: &streamReadCloser{err: io.ErrUnexpectedEOF},
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result, "失败移交场景下不应返回 streamingResult")
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err)
|
||||
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||
require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试")
|
||||
|
||||
// ResponseBody 必须是 Anthropic 标准 error 格式:
|
||||
// 1) ExtractUpstreamErrorMessage 能正确从 error.message 提取消息(被 handleFailoverExhausted / ops 日志依赖)
|
||||
// 2) error.type 标记为 upstream_disconnected
|
||||
extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody)
|
||||
require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message,否则 ops 日志会丢失诊断信息")
|
||||
require.Contains(t, extractedMsg, "upstream stream disconnected")
|
||||
require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`)
|
||||
|
||||
// 客户端应收不到任何 stream_read_error 事件,由 handler 层根据 failover 结果再决定
|
||||
require.NotContains(t, rec.Body.String(), "stream_read_error")
|
||||
}
|
||||
|
||||
// 上游已经发送过事件(c.Writer 已写过字节)后再发生读错误:
|
||||
// SSE 协议无 resume,网关只能透传 stream_read_error 错误事件给客户端,不能 failover。
|
||||
func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 第一次 Read 返回完整 SSE 事件让网关向 client 写入字节,第二次 Read 返回 EOF
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: &streamReadCloser{
|
||||
payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
|
||||
err: io.ErrUnexpectedEOF,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream read error", "已开始流后应透传普通 stream read error")
|
||||
require.NotNil(t, result, "透传场景下应返回已收集的 streamingResult")
|
||||
|
||||
// 不应被错误地包成 failover error
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.False(t, errors.As(err, &failoverErr), "已经向客户端写过字节时不能再 failover")
|
||||
|
||||
// 客户端必须收到 Anthropic 标准格式的 SSE error 事件,error.type=stream_read_error,
|
||||
// error.message 含具体根因(让 SDK 能解析、UI 能显示具体错误)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧")
|
||||
require.Contains(t, body, `"type":"error"`, "data 必须含 type:error 顶层字段(Anthropic 标准)")
|
||||
require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error")
|
||||
require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案")
|
||||
}
|
||||
|
||||
// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游
|
||||
// 服务器地址。sanitizeStreamError 必须剥离这些信息,避免基础设施拓扑通过
|
||||
// failover ResponseBody 或 SSE error 帧返回给客户端。
|
||||
func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) {
|
||||
src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
|
||||
require.NoError(t, err)
|
||||
dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw := &net.OpError{
|
||||
Op: "read",
|
||||
Net: "tcp",
|
||||
Source: src,
|
||||
Addr: dst,
|
||||
Err: syscall.ECONNRESET,
|
||||
}
|
||||
|
||||
// 前置:原始 Error() 确实包含会泄露的字段(避免测试在 Go 行为变化时静默通过)
|
||||
require.Contains(t, raw.Error(), "10.0.0.1")
|
||||
require.Contains(t, raw.Error(), "52.1.2.3")
|
||||
|
||||
got := sanitizeStreamError(raw)
|
||||
require.NotContains(t, got, "10.0.0.1", "不得泄露内部源 IP")
|
||||
require.NotContains(t, got, "54321", "不得泄露源端口")
|
||||
require.NotContains(t, got, "52.1.2.3", "不得泄露上游目标 IP")
|
||||
require.NotContains(t, got, "443", "不得泄露上游端口")
|
||||
require.Equal(t, "connection reset by peer", got)
|
||||
}
|
||||
|
||||
func TestSanitizeStreamError_KnownErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"},
|
||||
{"EOF", io.EOF, "EOF"},
|
||||
{"context canceled", context.Canceled, "canceled"},
|
||||
{"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"},
|
||||
{"ECONNRESET 直接", syscall.ECONNRESET, "connection reset by peer"},
|
||||
{"EPIPE", syscall.EPIPE, "broken pipe"},
|
||||
{"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"},
|
||||
{"未识别错误兜底", errors.New("weird internal error"), "upstream connection error"},
|
||||
{"nil 返回空串", nil, ""},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
require.Equal(t, tc.want, sanitizeStreamError(tc.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// failover ResponseBody 必须用 sanitize 过的消息,避免泄露给客户端 / 写入 ops 日志
|
||||
// 时携带内部地址信息。
|
||||
func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
|
||||
dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
|
||||
netErr := &net.OpError{
|
||||
Op: "read",
|
||||
Net: "tcp",
|
||||
Source: src,
|
||||
Addr: dst,
|
||||
Err: syscall.ECONNRESET,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: &streamReadCloser{err: netErr},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.True(t, errors.As(err, &failoverErr))
|
||||
|
||||
body := string(failoverErr.ResponseBody)
|
||||
require.NotContains(t, body, "10.0.0.1", "failover ResponseBody 不得泄露内部源 IP")
|
||||
require.NotContains(t, body, "54321")
|
||||
require.NotContains(t, body, "52.1.2.3", "failover ResponseBody 不得泄露上游 IP")
|
||||
require.NotContains(t, body, "443")
|
||||
// 仍然包含可诊断的根因
|
||||
require.Contains(t, body, "connection reset by peer")
|
||||
require.Contains(t, body, "upstream stream disconnected")
|
||||
}
|
||||
|
||||
@@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
}
|
||||
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
|
||||
return 3
|
||||
case AccountTypeServiceAccount:
|
||||
// Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
|
||||
// endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
|
||||
return 999
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
@@ -579,7 +583,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := req.Model
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(req.Model)
|
||||
}
|
||||
|
||||
@@ -712,6 +716,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
case AccountTypeServiceAccount:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
if s.tokenProvider == nil {
|
||||
return nil, "", errors.New("gemini token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if req.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return upstreamReq, "x-request-id", nil
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
@@ -1094,7 +1128,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
body = ensureGeminiFunctionCallThoughtSignatures(body)
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
|
||||
@@ -1213,6 +1247,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
case AccountTypeServiceAccount:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
if s.tokenProvider == nil {
|
||||
return nil, "", errors.New("gemini token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return upstreamReq, "x-request-id", nil
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
default:
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ const (
|
||||
geminiTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
|
||||
// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
|
||||
type GeminiTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache GeminiTokenCache
|
||||
@@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not a gemini oauth account")
|
||||
if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
|
||||
return "", errors.New("not a gemini oauth or service account")
|
||||
}
|
||||
if account.Type == AccountTypeServiceAccount {
|
||||
return p.getServiceAccountAccessToken(ctx, account)
|
||||
}
|
||||
|
||||
cacheKey := GeminiTokenCacheKey(account)
|
||||
@@ -168,7 +171,16 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
|
||||
}
|
||||
|
||||
func GeminiTokenCacheKey(account *Account) string {
|
||||
if account != nil && account.Type == AccountTypeServiceAccount {
|
||||
if key, err := parseVertexServiceAccountKey(account); err == nil {
|
||||
return vertexServiceAccountCacheKey(account, key)
|
||||
}
|
||||
}
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return "gemini:" + projectID
|
||||
|
||||
149
backend/internal/service/openai_apikey_responses_probe.go
Normal file
149
backend/internal/service/openai_apikey_responses_probe.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
)
|
||||
|
||||
// openaiResponsesProbeTimeout 是探测请求的超时时长。
|
||||
// 探测必须快速失败——超时不应阻塞账号创建/更新流程。
|
||||
const openaiResponsesProbeTimeout = 8 * time.Second
|
||||
|
||||
// openaiResponsesProbePayload 是探测使用的最小 Responses 请求体。
|
||||
// 仅作能力探测,不期望响应内容质量;Stream=false 减少 SSE 解析开销。
|
||||
//
|
||||
// 注意:探测的目标是区分"端点存在"与"端点不存在"——只要上游返回非 404 的
|
||||
// 4xx/5xx(如 400 invalid_request_error / 401 unauthorized / 422 等),
|
||||
// 都视为"端点存在 → 支持 Responses"。仅 404 / 405 视为"端点不存在"。
|
||||
func openaiResponsesProbePayload(modelID string) []byte {
|
||||
if strings.TrimSpace(modelID) == "" {
|
||||
modelID = openai.DefaultTestModel
|
||||
}
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": modelID,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{"type": "input_text", "text": "hi"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"instructions": openai.DefaultInstructions,
|
||||
"stream": false,
|
||||
})
|
||||
return body
|
||||
}
|
||||
|
||||
// ProbeOpenAIAPIKeyResponsesSupport 探测 OpenAI APIKey 账号上游是否支持
|
||||
// /v1/responses 端点,并将结果持久化到 accounts.extra.openai_responses_supported。
|
||||
//
|
||||
// 调用时机:账号创建/更新后,且仅当 platform=openai && type=apikey 时。
|
||||
//
|
||||
// 探测策略(参见包文档 internal/pkg/openai_compat):
|
||||
// - 上游 404 / 405 → 不支持,写 false
|
||||
// - 上游 2xx / 其他 4xx(401/422/400 等)/ 5xx → 支持,写 true
|
||||
// - 网络层失败(连接错误、超时)→ 不写标记,保持 unknown
|
||||
// (后续请求仍按"现状即证据"默认走 Responses)
|
||||
//
|
||||
// 该方法是幂等的:重复调用会以最新探测结果覆盖标记。
|
||||
//
|
||||
// 关于失败处理:探测本身的失败不应阻塞账号创建——账号能创建/更新成功就够了,
|
||||
// 探测结果只影响后续路由优化。所有错误都仅记录日志,不向调用方传播。
|
||||
func (s *AccountTestService) ProbeOpenAIAPIKeyResponsesSupport(ctx context.Context, accountID int64) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_load_account_failed: account_id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeAPIKey {
|
||||
// 仅 OpenAI APIKey 账号需要探测;其他账号类型无能力差异。
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_skip_no_apikey: account_id=%d", accountID)
|
||||
return
|
||||
}
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_invalid_baseurl: account_id=%d base_url=%q err=%v", accountID, baseURL, err)
|
||||
return
|
||||
}
|
||||
|
||||
probeURL := buildOpenAIResponsesURL(normalizedBaseURL)
|
||||
|
||||
probeCtx, cancel := context.WithTimeout(ctx, openaiResponsesProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(probeCtx, http.MethodPost, probeURL, bytes.NewReader(openaiResponsesProbePayload("")))
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_build_request_failed: account_id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
// 网络层失败:不写标记,保持 unknown,下次重试或由网关 fallback 处理
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_request_failed: account_id=%d url=%s err=%v", accountID, probeURL, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
supported := isResponsesEndpointSupportedByStatus(resp.StatusCode)
|
||||
|
||||
if err := s.accountRepo.UpdateExtra(ctx, accountID, map[string]any{
|
||||
openai_compat.ExtraKeyResponsesSupported: supported,
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.openai_probe", "probe_persist_failed: account_id=%d supported=%v err=%v", accountID, supported, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.openai_probe",
|
||||
"probe_done: account_id=%d base_url=%s status=%d supported=%v",
|
||||
accountID, normalizedBaseURL, resp.StatusCode, supported,
|
||||
)
|
||||
}
|
||||
|
||||
// isResponsesEndpointSupportedByStatus 根据探测响应的 HTTP 状态码判定上游
|
||||
// 是否暴露 /v1/responses 端点。
|
||||
//
|
||||
// 关键观察:第三方 OpenAI 兼容上游(DeepSeek/Kimi 等)对未知端点统一返回 404
|
||||
// 或 405;而 OpenAI 官方/有 Responses 实现的上游会因为请求体最简(缺字段)
|
||||
// 返回 400/422 等业务错误,但端点本身存在。
|
||||
//
|
||||
// 因此:仅 404 和 405 视为"端点不存在",其他 status 视为"端点存在"。
|
||||
//
|
||||
// 5xx 也视为"端点存在"——上游偶发故障不应误判为不支持。
|
||||
func isResponsesEndpointSupportedByStatus(status int) bool {
|
||||
switch status {
|
||||
case http.StatusNotFound, http.StatusMethodNotAllowed:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -53,6 +53,23 @@ const (
|
||||
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
|
||||
)
|
||||
|
||||
var openAIChatGPTInternalUnsupportedFields = []string{
|
||||
"user",
|
||||
"metadata",
|
||||
"prompt_cache_retention",
|
||||
"safety_identifier",
|
||||
"stream_options",
|
||||
}
|
||||
|
||||
var openAICodexOAuthUnsupportedFields = append([]string{
|
||||
"max_output_tokens",
|
||||
"max_completion_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
}, openAIChatGPTInternalUnsupportedFields...)
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
@@ -93,23 +110,8 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
}
|
||||
}
|
||||
|
||||
// Strip parameters unsupported by codex models via the Responses API.
|
||||
for _, key := range []string{
|
||||
"max_output_tokens",
|
||||
"max_completion_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
// prompt_cache_retention is a newer Responses API parameter (cache TTL).
|
||||
// The ChatGPT internal Codex endpoint rejects it with
|
||||
// "Unsupported parameter: prompt_cache_retention". Defense-in-depth
|
||||
// for any OAuth path that reaches this transform — the Cursor
|
||||
// Responses-shape short-circuit in ForwardAsChatCompletions strips
|
||||
// it earlier too, but we keep this line so other OAuth callers are
|
||||
// equally protected.
|
||||
"prompt_cache_retention",
|
||||
} {
|
||||
// Strip parameters unsupported by ChatGPT internal Codex endpoint.
|
||||
for _, key := range openAICodexOAuthUnsupportedFields {
|
||||
if _, ok := reqBody[key]; ok {
|
||||
delete(reqBody, key)
|
||||
result.Modified = true
|
||||
@@ -141,9 +143,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||
reqBody["tool_choice"] = map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
},
|
||||
"name": name,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -219,9 +219,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
|
||||
if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
|
||||
if choiceType == "" {
|
||||
return false
|
||||
}
|
||||
modified := false
|
||||
if choiceType == "function" {
|
||||
name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"]))
|
||||
if name == "" {
|
||||
if function, ok := choiceMap["function"].(map[string]any); ok {
|
||||
name = strings.TrimSpace(firstNonEmptyString(function["name"]))
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
reqBody["tool_choice"] = "auto"
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name {
|
||||
choiceMap["name"] = name
|
||||
modified = true
|
||||
}
|
||||
if _, ok := choiceMap["function"]; ok {
|
||||
delete(choiceMap, "function")
|
||||
modified = true
|
||||
}
|
||||
if !codexToolsContainFunctionName(reqBody["tools"], name) {
|
||||
reqBody["tool_choice"] = "auto"
|
||||
return true
|
||||
}
|
||||
return modified
|
||||
}
|
||||
if codexToolsContainType(reqBody["tools"], choiceType) {
|
||||
return modified
|
||||
}
|
||||
reqBody["tool_choice"] = "auto"
|
||||
return true
|
||||
}
|
||||
@@ -243,6 +272,33 @@ func codexToolsContainType(rawTools any, toolType string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func codexToolsContainFunctionName(rawTools any, name string) bool {
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok || strings.TrimSpace(name) == "" {
|
||||
return false
|
||||
}
|
||||
normalizedName := strings.TrimSpace(name)
|
||||
for _, rawTool := range tools {
|
||||
tool, ok := rawTool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" {
|
||||
continue
|
||||
}
|
||||
toolName := strings.TrimSpace(firstNonEmptyString(tool["name"]))
|
||||
if toolName == "" {
|
||||
if function, ok := tool["function"].(map[string]any); ok {
|
||||
toolName = strings.TrimSpace(firstNonEmptyString(function["name"]))
|
||||
}
|
||||
}
|
||||
if toolName == normalizedName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
|
||||
if len(input) == 0 {
|
||||
return input, false
|
||||
@@ -853,6 +909,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
|
||||
// chatgpt.com codex backend (OAuth path) does not persist reasoning
|
||||
// items because applyCodexOAuthTransform forces store=false. Any rs_*
|
||||
// reference replayed in input is guaranteed to 404 upstream
|
||||
// ("Item with id 'rs_...' not found"). Drop reasoning items entirely.
|
||||
if typ == "reasoning" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
|
||||
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
|
||||
fixCallIDPrefix := func(id string) string {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -249,6 +251,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
|
||||
require.Equal(t, "custom", choice["type"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"tools": []any{
|
||||
map[string]any{"type": "function", "name": "shell"},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{"name": "shell"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
choice, ok := reqBody["tool_choice"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "function", choice["type"])
|
||||
require.Equal(t, "shell", choice["name"])
|
||||
require.NotContains(t, choice, "function")
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"tools": []any{
|
||||
map[string]any{"type": "function", "name": "shell"},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{"name": "missing"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
require.Equal(t, "auto", reqBody["tool_choice"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
@@ -1048,6 +1088,27 @@ func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) {
|
||||
"prompt_cache_retention must be stripped before forwarding to Codex upstream")
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_StripsChatGPTInternalUnsupportedFields(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"user": "user_123",
|
||||
"metadata": map[string]any{"trace_id": "abc"},
|
||||
"prompt_cache_retention": "24h",
|
||||
"safety_identifier": "sid",
|
||||
"stream_options": map[string]any{"include_usage": true},
|
||||
"input": []any{
|
||||
map[string]any{"role": "user", "content": "hi"},
|
||||
},
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
require.True(t, result.Modified)
|
||||
for _, field := range openAIChatGPTInternalUnsupportedFields {
|
||||
require.NotContains(t, reqBody, field)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -1094,3 +1155,56 @@ func TestIsInstructionsEmpty(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) {
|
||||
// Reasoning items in input[] reference rs_* IDs that were emitted by
|
||||
// chatgpt.com under store=false (forced by applyCodexOAuthTransform).
|
||||
// They are never persisted upstream, so forwarding them produces a
|
||||
// guaranteed 404 ("Item with id 'rs_...' not found"). Drop them
|
||||
// regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957.
|
||||
|
||||
build := func() []any {
|
||||
return []any{
|
||||
map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
|
||||
map[string]any{
|
||||
"type": "reasoning",
|
||||
"id": "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344",
|
||||
"summary": []any{},
|
||||
},
|
||||
map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"},
|
||||
}
|
||||
}
|
||||
|
||||
for _, preserve := range []bool{true, false} {
|
||||
preserve := preserve
|
||||
t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) {
|
||||
filtered := filterCodexInput(build(), preserve)
|
||||
|
||||
for _, raw := range filtered {
|
||||
item, ok := raw.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, "reasoning", item["type"],
|
||||
"reasoning items must be dropped from input on the OAuth path")
|
||||
if id, ok := item["id"].(string); ok {
|
||||
require.False(t, strings.HasPrefix(id, "rs_"),
|
||||
"no item carrying an rs_* id should survive the filter")
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity check: the non-reasoning items should still be present.
|
||||
gotTypes := make(map[string]int)
|
||||
for _, raw := range filtered {
|
||||
item, ok := raw.(map[string]any)
|
||||
require.True(t, ok)
|
||||
typ, ok := item["type"].(string)
|
||||
require.True(t, ok)
|
||||
gotTypes[typ]++
|
||||
}
|
||||
require.Equal(t, 1, gotTypes["message"])
|
||||
require.Equal(t, 1, gotTypes["function_call"])
|
||||
require.Equal(t, 1, gotTypes["function_call_output"])
|
||||
require.Equal(t, 0, gotTypes["reasoning"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,13 +3,16 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
@@ -18,6 +21,51 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAICompatFailingWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *openAICompatFailingWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
type openAICompatBlockingReadCloser struct {
|
||||
data []byte
|
||||
offset int
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser {
|
||||
return &openAICompatBlockingReadCloser{
|
||||
data: data,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) {
|
||||
if r.offset < len(r.data) {
|
||||
n := copy(p, r.data[r.offset:])
|
||||
r.offset += n
|
||||
return n, nil
|
||||
}
|
||||
<-r.closed
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (r *openAICompatBlockingReadCloser) Close() error {
|
||||
r.closeOnce.Do(func() {
|
||||
close(r.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -228,3 +276,242 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
|
||||
"",
|
||||
`data: {"type":"response.output_text.delta","delta":"ok"}`,
|
||||
"",
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 9, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
|
||||
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||
defer func() {
|
||||
require.NoError(t, upstreamStream.Close())
|
||||
}()
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}},
|
||||
Body: upstreamStream,
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
type forwardResult struct {
|
||||
result *OpenAIForwardResult
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan forwardResult, 1)
|
||||
go func() {
|
||||
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
resultCh <- forwardResult{result: result, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case got := <-resultCh:
|
||||
require.NoError(t, got.err)
|
||||
require.NotNil(t, got.result)
|
||||
require.Equal(t, 15, got.result.Usage.InputTokens)
|
||||
require.Equal(t, 6, got.result.Usage.OutputTokens)
|
||||
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
|
||||
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||
defer func() {
|
||||
require.NoError(t, upstreamStream.Close())
|
||||
}()
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}},
|
||||
Body: upstreamStream,
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
type forwardResult struct {
|
||||
result *OpenAIForwardResult
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan forwardResult, 1)
|
||||
go func() {
|
||||
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
resultCh <- forwardResult{result: result, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case got := <-resultCh:
|
||||
require.NoError(t, got.err)
|
||||
require.NotNil(t, got.result)
|
||||
require.Equal(t, 15, got.result.Usage.InputTokens)
|
||||
require.Equal(t, 6, got.result.Usage.OutputTokens)
|
||||
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
|
||||
require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := "data: [DONE]\n\n"
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
require.Zero(t, result.Usage.InputTokens)
|
||||
require.Zero(t, result.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
cancel()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
286
backend/internal/service/openai_fast_policy_test.go
Normal file
286
backend/internal/service/openai_fast_policy_test.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIFastPolicyRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
if s.values == nil {
|
||||
s.values = map[string]string{}
|
||||
}
|
||||
s.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
|
||||
t.Helper()
|
||||
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
|
||||
if settings != nil {
|
||||
raw, err := json.Marshal(settings)
|
||||
require.NoError(t, err)
|
||||
repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
|
||||
}
|
||||
return &OpenAIGatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast
|
||||
// 是用户级开关,与 model 正交。
|
||||
// gpt-5.5 + priority → filter
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-5.5-turbo → filter
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-4 + priority → filter(默认策略覆盖所有模型)
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-5.5 + flex → pass (tier doesn't match)
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
|
||||
// empty tier → pass
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "fast mode is not allowed",
|
||||
ModelWhitelist: []string{"gpt-5.5"},
|
||||
FallbackAction: BetaPolicyActionPass,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionBlock, action)
|
||||
require.Equal(t, "fast mode is not allowed", msg)
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierAny,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeOAuth,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
|
||||
// OAuth account → rule matches
|
||||
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// API Key account → rule skipped → pass
|
||||
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
}
|
||||
|
||||
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// gpt-5.5 fast → service_tier stripped
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// Client sending "fast" (alias for priority) also filtered
|
||||
body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除
|
||||
body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// No service_tier → no-op
|
||||
body = []byte(`{"model":"gpt-5.5"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(updated))
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后
|
||||
// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被
|
||||
// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。
|
||||
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
for _, tier := range []string{"auto", "default", "scale"} {
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err, "tier %q should pass without error", tier)
|
||||
require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
|
||||
"tier %q should be preserved in body under default rule", tier)
|
||||
}
|
||||
|
||||
// evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配)
|
||||
for _, tier := range []string{"auto", "default", "scale"} {
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
|
||||
require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
|
||||
// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会
|
||||
// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。
|
||||
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierAny,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`,
|
||||
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped 验证真未知 tier 仍被剥离
|
||||
// (normalize 返回 nil → normalizeResponsesBodyServiceTier 删除字段;
|
||||
// applyOpenAIFastPolicyToBody 在 normTier 为空时直接 no-op,因为字段已不可能存在
|
||||
// 于经过前置归一化的请求里。这里直接调 apply 验证它对未识别值不会异常)。
|
||||
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// normalize 阶段会将未知值剥离
|
||||
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
|
||||
|
||||
// applyOpenAIFastPolicyToBody 收到未识别 tier 时不报错,body 透传不变
|
||||
// (不属于本函数职责——上层 normalizeResponsesBodyServiceTier 已剥离)
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(updated))
|
||||
}
|
||||
|
||||
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "fast mode is blocked for gpt-5.5",
|
||||
ModelWhitelist: []string{"gpt-5.5"},
|
||||
FallbackAction: BetaPolicyActionPass,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.Error(t, err)
|
||||
var blocked *OpenAIFastBlockedError
|
||||
require.True(t, errors.As(err, &blocked))
|
||||
require.Contains(t, blocked.Message, "fast mode is blocked")
|
||||
require.Equal(t, string(body), string(updated)) // body not mutated on block
|
||||
}
|
||||
|
||||
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
|
||||
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
// Invalid action rejected
|
||||
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: "bogus",
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// Invalid service_tier rejected
|
||||
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: "turbo",
|
||||
Action: BetaPolicyActionPass,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// Valid settings persisted
|
||||
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, got.Rules, 1)
|
||||
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
|
||||
}
|
||||
1074
backend/internal/service/openai_fast_policy_ws_test.go
Normal file
1074
backend/internal/service/openai_fast_policy_ws_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
|
||||
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) {
|
||||
counter := &openAI403CounterResetStub{}
|
||||
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
|
||||
rateLimitSvc.SetOpenAI403CounterCache(counter)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
rateLimitService: rateLimitSvc,
|
||||
}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
svc.rateLimitService = rateLimitSvc
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{},
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_zero_usage_reset_403",
|
||||
Model: "gpt-5.1",
|
||||
},
|
||||
APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}},
|
||||
User: &User{ID: 2001},
|
||||
Account: &Account{ID: 777, Platform: PlatformOpenAI},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{777}, counter.resetCalls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
}
|
||||
|
||||
@@ -10,10 +10,12 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -39,9 +41,18 @@ var cursorResponsesUnsupportedFields = []string{
|
||||
|
||||
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
|
||||
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
|
||||
// the response back to Chat Completions format. All account types (OAuth and API
|
||||
// Key) go through the Responses API conversion path since the upstream only
|
||||
// exposes the /v1/responses endpoint.
|
||||
// the response back to Chat Completions format.
|
||||
//
|
||||
// 历史背景:该函数原本对所有 OpenAI 账号无差别走 CC→Responses 转换 + /v1/responses
|
||||
// 端点——这在 OAuth(ChatGPT 内部 API 仅支持 Responses)和官方 APIKey 账号上是
|
||||
// 正确的,但 sub2api 接入 DeepSeek/Kimi/GLM 等第三方 OpenAI 兼容上游后假设破裂:
|
||||
// 这些上游普遍只支持 /v1/chat/completions,无 /v1/responses 端点。
|
||||
//
|
||||
// 当前路由策略(基于账号探测标记,详见 openai_compat.ShouldUseResponsesAPI):
|
||||
// - APIKey 账号 + 探测确认不支持 Responses → 走 forwardAsRawChatCompletions
|
||||
// 直转上游 /v1/chat/completions,不做协议转换
|
||||
// - 其他所有情况(OAuth、APIKey 探测确认支持、未探测)→ 走原有 CC→Responses
|
||||
// 转换路径(保留旧行为,存量未探测账号零兼容破坏)
|
||||
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
@@ -50,6 +61,12 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
promptCacheKey string,
|
||||
defaultMappedModel string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
// 入口分流:APIKey 账号 + 已探测且确认上游不支持 Responses,走 CC 直转。
|
||||
// 标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。
|
||||
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse Chat Completions request
|
||||
@@ -171,6 +188,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
}
|
||||
|
||||
// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
responsesBody = updatedBody
|
||||
|
||||
// 5. Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -178,7 +206,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
@@ -337,59 +367,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai chat_completions buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate delta content for fallback when terminal output is empty.
|
||||
acc.ProcessEvent(&event)
|
||||
|
||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai chat_completions buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
@@ -448,6 +428,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
clientDisconnected := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -456,6 +437,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
@@ -485,54 +480,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil && event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||
}
|
||||
|
||||
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
||||
for _, chunk := range chunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
logger.L().Info("openai chat_completions stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return true
|
||||
if !clientDisconnected {
|
||||
for _, chunk := range chunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(chunks) > 0 {
|
||||
if len(chunks) > 0 && !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
return isTerminalEvent
|
||||
}
|
||||
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
|
||||
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
|
||||
for _, chunk := range finalChunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Send [DONE] sentinel
|
||||
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
|
||||
c.Writer.Flush()
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
@@ -544,6 +551,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
)
|
||||
}
|
||||
}
|
||||
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
// Determine keepalive interval
|
||||
keepaliveInterval := time.Duration(0)
|
||||
@@ -552,18 +562,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
|
||||
// No keepalive: fast synchronous path
|
||||
if keepaliveInterval <= 0 {
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if strings.TrimSpace(payload) == "[DONE]" {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
if err := scanner.Err(); err != nil {
|
||||
handleScanErr(err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
return missingTerminalErr()
|
||||
}
|
||||
|
||||
// With keepalive: goroutine + channel + select
|
||||
@@ -573,6 +590,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
@@ -584,6 +603,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
@@ -594,30 +614,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return finalizeStream()
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if strings.TrimSpace(payload) == "[DONE]" {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.L().Warn("openai chat_completions stream: data interval timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("model", originalModel),
|
||||
zap.Duration("interval", streamInterval),
|
||||
)
|
||||
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
@@ -626,7 +675,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
clientDisconnected = true
|
||||
continue
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
437
backend/internal/service/openai_gateway_chat_completions_raw.go
Normal file
437
backend/internal/service/openai_gateway_chat_completions_raw.go
Normal file
@@ -0,0 +1,437 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// openaiCCRawAllowedHeaders 是 CC 直转路径专用的客户端 header 透传白名单。
|
||||
//
|
||||
// **关键**:不能复用 openaiAllowedHeaders——后者含 Codex 客户端专属 header
|
||||
// (originator / session_id / x-codex-turn-state / x-codex-turn-metadata / conversation_id),
|
||||
// 这些在 ChatGPT OAuth 上游是必需的,但透传给 DeepSeek/Kimi/GLM 等第三方
|
||||
// OpenAI 兼容上游会造成:
|
||||
// - 完全忽略(多数友好厂商)——隐性污染上游统计
|
||||
// - 400 "unknown parameter"(严格上游)——可见错误
|
||||
//
|
||||
// 这里仅放行通用 HTTP header;content-type / authorization / accept 由上下文
|
||||
// 显式设置,不依赖透传。
|
||||
//
|
||||
// 参见决策记录:
|
||||
// pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains
|
||||
var openaiCCRawAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
"user-agent": true,
|
||||
}
|
||||
|
||||
// forwardAsRawChatCompletions 直转客户端的 Chat Completions 请求到上游
|
||||
// `{base_url}/v1/chat/completions`,**不**做 CC↔Responses 协议转换。
|
||||
//
|
||||
// 适用场景:account.platform=openai && account.type=apikey && 上游已被探测确认
|
||||
// 不支持 /v1/responses 端点(如 DeepSeek/Kimi/GLM/Qwen 等第三方 OpenAI 兼容上游)。
|
||||
//
|
||||
// 与 ForwardAsChatCompletions 的关键差异:
|
||||
//
|
||||
// - 不调用 apicompat.ChatCompletionsToResponses,body 仅做模型 ID 改写
|
||||
// - 上游 URL 拼到 /v1/chat/completions 而非 /v1/responses
|
||||
// - 流式响应 SSE 直接透传给客户端(上游 chunk 已是 CC 格式)
|
||||
// - 非流式响应 JSON 直接透传,仅按需提取 usage
|
||||
// - 不应用 codex OAuth transform(APIKey 路径无 OAuth)
|
||||
// - 不注入 prompt_cache_key(OAuth 专属机制)
|
||||
//
|
||||
// 调用入口:openai_gateway_chat_completions.go::ForwardAsChatCompletions
|
||||
// 在函数顶部按 openai_compat.ShouldUseResponsesAPI 分流。
|
||||
func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
defaultMappedModel string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse minimal fields needed for routing/billing
|
||||
originalModel := gjson.GetBytes(body, "model").String()
|
||||
if originalModel == "" {
|
||||
writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return nil, fmt.Errorf("missing model in request")
|
||||
}
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
// 1b. Extract reasoning effort and service tier from the raw body before any transformation.
|
||||
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
|
||||
serviceTier := extractOpenAIServiceTierFromBody(body)
|
||||
|
||||
// 2. Resolve model mapping (same as ForwardAsChatCompletions)
|
||||
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||||
|
||||
// 3. Rewrite model in body (no protocol conversion)
|
||||
upstreamBody := body
|
||||
if upstreamModel != originalModel {
|
||||
upstreamBody = ReplaceModelInBody(body, upstreamModel)
|
||||
}
|
||||
|
||||
// 4. Apply OpenAI fast policy on the CC body
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, upstreamBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
upstreamBody = updatedBody
|
||||
if clientStream {
|
||||
var usageErr error
|
||||
upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody)
|
||||
if usageErr != nil {
|
||||
return nil, fmt.Errorf("enable stream usage: %w", usageErr)
|
||||
}
|
||||
}
|
||||
|
||||
logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("billing_model", billingModel),
|
||||
zap.String("upstream_model", upstreamModel),
|
||||
zap.Bool("stream", clientStream),
|
||||
)
|
||||
|
||||
// 5. Build upstream request
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("account %d missing api_key", account.ID)
|
||||
}
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if clientStream {
|
||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
upstreamReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiCCRawAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
upstreamReq.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
upstreamReq.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// 6. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 7. Handle error response with failover
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleChatCompletionsErrorResponse(resp, c, account)
|
||||
}
|
||||
|
||||
// 8. Forward response
|
||||
if clientStream {
|
||||
return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||
}
|
||||
return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||
}
|
||||
|
||||
// streamRawChatCompletions 透传上游 CC SSE 流到客户端,并提取 usage(包括
|
||||
// 末尾 [DONE] 之前的 chunk 中的 usage 字段,按 OpenAI CC 协议)。
|
||||
//
|
||||
// usage 字段仅在客户端请求 stream_options.include_usage=true 时出现于上游响应中。
|
||||
// 网关会对上游强制打开 include_usage 以保证计费完整,并原样向下游透传 usage,
|
||||
// 让级联代理或下游计费系统也能拿到完整用量。
|
||||
func (s *OpenAIGatewayService) streamRawChatCompletions(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
originalModel string,
|
||||
billingModel string,
|
||||
upstreamModel string,
|
||||
reasoningEffort *string,
|
||||
serviceTier *string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if payload, ok := extractOpenAISSEDataLine(line); ok {
|
||||
trimmedPayload := strings.TrimSpace(payload)
|
||||
if trimmedPayload != "[DONE]" {
|
||||
usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload)
|
||||
if u := extractCCStreamUsage(payload); u != nil {
|
||||
usage = *u
|
||||
}
|
||||
if firstTokenMs == nil && !usageOnlyChunk {
|
||||
elapsed := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &elapsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
|
||||
zap.Error(werr),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
if line == "" {
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai chat_completions raw: stream read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ensureOpenAIChatStreamUsage 确保 raw Chat Completions 流式请求会让上游返回 usage。
|
||||
// usage 也会继续向下游透传,支持级联代理和下游计费系统。
|
||||
func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) {
|
||||
updated, err := sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func isOpenAIChatUsageOnlyStreamChunk(payload string) bool {
|
||||
if strings.TrimSpace(payload) == "" {
|
||||
return false
|
||||
}
|
||||
if !gjson.Get(payload, "usage").Exists() {
|
||||
return false
|
||||
}
|
||||
choices := gjson.Get(payload, "choices")
|
||||
return choices.Exists() && choices.IsArray() && len(choices.Array()) == 0
|
||||
}
|
||||
|
||||
// extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。
|
||||
// CC 协议中 usage 仅出现在末尾 chunk(且仅当 include_usage 生效时),
|
||||
// 但上游可能在多个 chunk 中重复——总是用最新值。
|
||||
func extractCCStreamUsage(payload string) *OpenAIUsage {
|
||||
usageResult := gjson.Get(payload, "usage")
|
||||
if !usageResult.Exists() || !usageResult.IsObject() {
|
||||
return nil
|
||||
}
|
||||
u := OpenAIUsage{
|
||||
InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()),
|
||||
OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()),
|
||||
}
|
||||
if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() {
|
||||
u.CacheReadInputTokens = int(cached.Int())
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
// bufferRawChatCompletions 透传上游 CC 非流式 JSON 响应。
|
||||
func (s *OpenAIGatewayService) bufferRawChatCompletions(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
originalModel string,
|
||||
billingModel string,
|
||||
upstreamModel string,
|
||||
reasoningEffort *string,
|
||||
serviceTier *string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
|
||||
}
|
||||
return nil, fmt.Errorf("read upstream body: %w", err)
|
||||
}
|
||||
|
||||
var ccResp apicompat.ChatCompletionsResponse
|
||||
var usage OpenAIUsage
|
||||
if err := json.Unmarshal(respBody, &ccResp); err == nil && ccResp.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: ccResp.Usage.PromptTokens,
|
||||
OutputTokens: ccResp.Usage.CompletionTokens,
|
||||
}
|
||||
if ccResp.Usage.PromptTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = ccResp.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||
c.Writer.Header().Set("Content-Type", ct)
|
||||
} else {
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildOpenAIChatCompletionsURL 拼接上游 Chat Completions 端点 URL。
|
||||
//
|
||||
// - base 已是 /chat/completions:原样返回
|
||||
// - base 以 /v1 结尾:追加 /chat/completions
|
||||
// - 其他情况:追加 /v1/chat/completions
|
||||
//
|
||||
// 与 buildOpenAIResponsesURL 是姐妹函数。
|
||||
func buildOpenAIChatCompletionsURL(base string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
|
||||
if strings.HasSuffix(normalized, "/chat/completions") {
|
||||
return normalized
|
||||
}
|
||||
if strings.HasSuffix(normalized, "/v1") {
|
||||
return normalized + "/chat/completions"
|
||||
}
|
||||
return normalized + "/v1/chat/completions"
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
base string
|
||||
want string
|
||||
}{
|
||||
// 已是 /chat/completions:原样返回
|
||||
{"already chat/completions", "https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions"},
|
||||
// 以 /v1 结尾:追加 /chat/completions
|
||||
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/chat/completions"},
|
||||
// 其他情况:追加 /v1/chat/completions
|
||||
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/chat/completions"},
|
||||
{"domain with trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/chat/completions"},
|
||||
// 第三方上游常见形式
|
||||
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
|
||||
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
|
||||
// 带空白字符
|
||||
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := buildOpenAIChatCompletionsURL(tt.base)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildOpenAIResponsesURL_ProbeURL 锁定 probe/测试端点使用的 URL 构建逻辑,
|
||||
// 确保 buildOpenAIResponsesURL 对标准 OpenAI base_url 格式均拼出 `/v1/responses`。
|
||||
func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
base string
|
||||
want string
|
||||
}{
|
||||
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/responses"},
|
||||
{"domain trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/responses"},
|
||||
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
|
||||
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
|
||||
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
|
||||
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := buildOpenAIResponsesURL(tt.base)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 9, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
require.Contains(t, rec.Body.String(), `"usage"`)
|
||||
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 17, result.Usage.InputTokens)
|
||||
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||
require.Equal(t, 6, result.Usage.CacheReadInputTokens)
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
cancel()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(``))
|
||||
}
|
||||
|
||||
func TestEnsureOpenAIChatStreamUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body, err := ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4"}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||||
|
||||
body, err = ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||||
}
|
||||
|
||||
func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader("toolong")),
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()}
|
||||
svc.cfg.Gateway.UpstreamResponseReadMaxBytes = 3
|
||||
|
||||
result, err := svc.bufferRawChatCompletions(c, resp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now())
|
||||
require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
}
|
||||
|
||||
func rawChatCompletionsTestConfig() *config.Config {
|
||||
return &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: false,
|
||||
AllowInsecureHTTP: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func rawChatCompletionsTestAccount() *Account {
|
||||
return &Account{
|
||||
ID: 101,
|
||||
Name: "raw-openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "http://upstream.example",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAIChatFailingWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -19,8 +42,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "flex", req.ServiceTier)
|
||||
|
||||
// OpenAI 官方合法 tier 应被透传保留。
|
||||
req.ServiceTier = "auto"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "auto", req.ServiceTier)
|
||||
|
||||
req.ServiceTier = "default"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "default", req.ServiceTier)
|
||||
|
||||
req.ServiceTier = "scale"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "scale", req.ServiceTier)
|
||||
|
||||
// 真未知值仍被剥离。
|
||||
req.ServiceTier = "turbo"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Empty(t, req.ServiceTier)
|
||||
}
|
||||
|
||||
@@ -37,8 +74,264 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
||||
require.Equal(t, "flex", tier)
|
||||
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
// OpenAI 官方 tier 直接保留在 body 中(透传上游)。
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "auto", tier)
|
||||
require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "default", tier)
|
||||
require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "scale", tier)
|
||||
require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
// 真未知值才会被删除。
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tier)
|
||||
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
|
||||
"",
|
||||
`data: {"type":"response.output_text.delta","delta":"ok"}`,
|
||||
"",
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 11, result.Usage.InputTokens)
|
||||
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||
require.Equal(t, 4, result.Usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
|
||||
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||
defer func() {
|
||||
require.NoError(t, upstreamStream.Close())
|
||||
}()
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
|
||||
Body: upstreamStream,
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
type forwardResult struct {
|
||||
result *OpenAIForwardResult
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan forwardResult, 1)
|
||||
go func() {
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
resultCh <- forwardResult{result: result, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case got := <-resultCh:
|
||||
require.NoError(t, got.err)
|
||||
require.NotNil(t, got.result)
|
||||
require.Equal(t, 17, got.result.Usage.InputTokens)
|
||||
require.Equal(t, 8, got.result.Usage.OutputTokens)
|
||||
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
|
||||
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||
defer func() {
|
||||
require.NoError(t, upstreamStream.Close())
|
||||
}()
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
|
||||
Body: upstreamStream,
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
type forwardResult struct {
|
||||
result *OpenAIForwardResult
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan forwardResult, 1)
|
||||
go func() {
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
resultCh <- forwardResult{result: result, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case got := <-resultCh:
|
||||
require.NoError(t, got.err)
|
||||
require.NotNil(t, got.result)
|
||||
require.Equal(t, 17, got.result.Usage.InputTokens)
|
||||
require.Equal(t, 8, got.result.Usage.OutputTokens)
|
||||
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
|
||||
require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := "data: [DONE]\n\n"
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
require.Zero(t, result.Usage.InputTokens)
|
||||
require.Zero(t, result.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
cancel()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
@@ -143,6 +144,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
}
|
||||
|
||||
// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
|
||||
// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
|
||||
// on the body-level service_tier field (priority/flex).
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
responsesBody = updatedBody
|
||||
|
||||
// 5. Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -150,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
@@ -283,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate delta content for fallback when terminal output is empty.
|
||||
acc.ProcessEvent(&event)
|
||||
|
||||
// Terminal events carry the complete ResponsesResponse with output + usage.
|
||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
@@ -367,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
|
||||
switch strings.TrimSpace(eventType) {
|
||||
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAICompatDoneSentinelLine(line string) bool {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
|
||||
resp *http.Response,
|
||||
logPrefix string,
|
||||
requestID string,
|
||||
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
var usage OpenAIUsage
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, usage, acc, errors.New("upstream response body is nil")
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var timeoutCh <-chan time.Time
|
||||
var timeoutTimer *time.Timer
|
||||
resetTimeout := func() {
|
||||
if streamInterval <= 0 {
|
||||
return
|
||||
}
|
||||
if timeoutTimer == nil {
|
||||
timeoutTimer = time.NewTimer(streamInterval)
|
||||
timeoutCh = timeoutTimer.C
|
||||
return
|
||||
}
|
||||
if !timeoutTimer.Stop() {
|
||||
select {
|
||||
case <-timeoutTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timeoutTimer.Reset(streamInterval)
|
||||
}
|
||||
stopTimeout := func() {
|
||||
if timeoutTimer == nil {
|
||||
return
|
||||
}
|
||||
if !timeoutTimer.Stop() {
|
||||
select {
|
||||
case <-timeoutTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
resetTimeout()
|
||||
defer stopTimeout()
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case events <- scanEvent{line: scanner.Text()}:
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
select {
|
||||
case events <- scanEvent{err: err}:
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return nil, usage, acc, nil
|
||||
}
|
||||
resetTimeout()
|
||||
if ev.err != nil {
|
||||
if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
logger.L().Warn(logPrefix+": read error",
|
||||
zap.Error(ev.err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
return nil, usage, acc, ev.err
|
||||
}
|
||||
|
||||
if isOpenAICompatDoneSentinelLine(ev.line) {
|
||||
return nil, usage, acc, nil
|
||||
}
|
||||
payload, ok := extractOpenAISSEDataLine(ev.line)
|
||||
if !ok || payload == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn(logPrefix+": failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
acc.ProcessEvent(&event)
|
||||
|
||||
if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
|
||||
if event.Response.Usage != nil {
|
||||
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||
}
|
||||
return event.Response, usage, acc, nil
|
||||
}
|
||||
|
||||
case <-timeoutCh:
|
||||
_ = resp.Body.Close()
|
||||
logger.L().Warn(logPrefix+": data interval timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Duration("interval", streamInterval),
|
||||
)
|
||||
return nil, usage, acc, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
||||
// converts each to Anthropic SSE events, and writes them to the client.
|
||||
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
||||
@@ -396,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
clientDisconnected := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -404,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// resultWithUsage builds the final result snapshot.
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
@@ -419,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
|
||||
// processDataLine handles a single "data: ..." SSE line from upstream.
|
||||
// Returns (clientDisconnected bool).
|
||||
processDataLine := func(payload string) bool {
|
||||
if firstChunk {
|
||||
firstChunk = false
|
||||
@@ -436,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil && event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||
}
|
||||
|
||||
// Convert to Anthropic events
|
||||
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
|
||||
for _, evt := range events {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
logger.L().Info("openai messages stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return true
|
||||
if !clientDisconnected {
|
||||
for _, evt := range events {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(events) > 0 {
|
||||
if len(events) > 0 && !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
return isTerminalEvent
|
||||
}
|
||||
|
||||
// finalizeStream sends any remaining Anthropic events and returns the result.
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai messages stream: client disconnected during final flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
@@ -496,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
)
|
||||
}
|
||||
}
|
||||
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
// ── Determine keepalive interval ──
|
||||
keepaliveInterval := time.Duration(0)
|
||||
@@ -504,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
|
||||
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
||||
if keepaliveInterval <= 0 {
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
if isOpenAICompatDoneSentinelLine(line) {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
if err := scanner.Err(); err != nil {
|
||||
handleScanErr(err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
return missingTerminalErr()
|
||||
}
|
||||
|
||||
// ── With keepalive: goroutine + channel + select ──
|
||||
@@ -525,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
@@ -536,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
@@ -546,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
@@ -555,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// Upstream closed
|
||||
return finalizeStream()
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
if isOpenAICompatDoneSentinelLine(line) {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.L().Warn("openai messages stream: data interval timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("model", originalModel),
|
||||
zap.Duration("interval", streamInterval),
|
||||
)
|
||||
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
@@ -580,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
clientDisconnected = true
|
||||
continue
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
@@ -597,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
|
||||
if usage == nil {
|
||||
return OpenAIUsage{}
|
||||
}
|
||||
result := OpenAIUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
}
|
||||
if usage.InputTokensDetails != nil {
|
||||
result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
@@ -185,6 +186,56 @@ func max(a, b int) int {
|
||||
return b
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_zero_usage",
|
||||
Usage: OpenAIUsage{},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}},
|
||||
User: &User{ID: 2000},
|
||||
Account: &Account{ID: 3000, Type: AccountTypeAPIKey},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID)
|
||||
require.Zero(t, usageRepo.lastLog.InputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.OutputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.CacheCreationTokens)
|
||||
require.Zero(t, usageRepo.lastLog.CacheReadTokens)
|
||||
require.Zero(t, usageRepo.lastLog.ImageOutputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.ImageCount)
|
||||
require.Zero(t, usageRepo.lastLog.InputCost)
|
||||
require.Zero(t, usageRepo.lastLog.OutputCost)
|
||||
require.Zero(t, usageRepo.lastLog.TotalCost)
|
||||
require.Zero(t, usageRepo.lastLog.ActualCost)
|
||||
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Zero(t, billingRepo.lastCmd.BalanceCost)
|
||||
require.Zero(t, billingRepo.lastCmd.SubscriptionCost)
|
||||
require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost)
|
||||
require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost)
|
||||
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||
groupID := int64(11)
|
||||
groupRate := 1.4
|
||||
@@ -826,18 +877,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *got)
|
||||
})
|
||||
|
||||
t.Run("default ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("default"))
|
||||
t.Run("openai official tiers preserved", func(t *testing.T) {
|
||||
// OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄
|
||||
// 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex,
|
||||
// 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。
|
||||
for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
|
||||
got := normalizeOpenAIServiceTier(tier)
|
||||
require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
|
||||
require.Equal(t, tier, *got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
|
||||
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
|
||||
require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
|
||||
require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
|
||||
require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
|
||||
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
|
||||
require.Nil(t, extractOpenAIServiceTier(nil))
|
||||
}
|
||||
@@ -845,7 +907,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
|
||||
require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
|
||||
}
|
||||
|
||||
|
||||
@@ -334,6 +334,7 @@ type OpenAIGatewayService struct {
|
||||
resolver *ModelPricingResolver
|
||||
channelService *ChannelService
|
||||
balanceNotifyService *BalanceNotifyService
|
||||
settingService *SettingService
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
@@ -372,6 +373,7 @@ func NewOpenAIGatewayService(
|
||||
resolver *ModelPricingResolver,
|
||||
channelService *ChannelService,
|
||||
balanceNotifyService *BalanceNotifyService,
|
||||
settingService *SettingService,
|
||||
) *OpenAIGatewayService {
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -402,6 +404,7 @@ func NewOpenAIGatewayService(
|
||||
resolver: resolver,
|
||||
channelService: channelService,
|
||||
balanceNotifyService: balanceNotifyService,
|
||||
settingService: settingService,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
@@ -1125,6 +1128,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
|
||||
return sessionID
|
||||
}
|
||||
|
||||
func explicitOpenAISessionID(c *gin.Context, body []byte) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
|
||||
// client session signals. It intentionally skips content-derived fallback and is
|
||||
// used by stateless endpoints such as /v1/images.
|
||||
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
|
||||
sessionID := explicitOpenAISessionID(c, body)
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
|
||||
attachOpenAILegacySessionHashToGin(c, legacyHash)
|
||||
return currentHash
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
|
||||
//
|
||||
// Priority:
|
||||
@@ -1137,13 +1169,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
sessionID := explicitOpenAISessionID(c, body)
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = deriveOpenAIContentSessionSeed(body)
|
||||
}
|
||||
@@ -2287,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
disablePatch()
|
||||
}
|
||||
|
||||
// Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤):
|
||||
// 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略
|
||||
// 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽
|
||||
// fast 时在此生效。
|
||||
//
|
||||
// 注意:
|
||||
// 1. 此处统一使用 upstreamModel(已经过 GetMappedModel +
|
||||
// normalizeOpenAIModelForUpstream + Codex OAuth normalize),与
|
||||
// chat-completions / messages 入口保持一致,避免不同入口因为模型
|
||||
// 维度不同而出现 whitelist 命中差异。
|
||||
// 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body,
|
||||
// 否则 native /responses 入口透传 "fast" 给上游会被拒。chat-
|
||||
// completions 入口由 normalizeResponsesBodyServiceTier 完成同一
|
||||
// 行为,这里手工实现等效逻辑。
|
||||
if rawTier, ok := reqBody["service_tier"].(string); ok {
|
||||
if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
|
||||
}
|
||||
blocked := &OpenAIFastBlockedError{Message: msg}
|
||||
writeOpenAIFastPolicyBlockedResponse(c, blocked)
|
||||
return nil, blocked
|
||||
case BetaPolicyActionFilter:
|
||||
delete(reqBody, "service_tier")
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
default:
|
||||
// pass:若客户端传的是别名 "fast",归一化为 "priority"
|
||||
// 后写回 body,确保上游收到的是其能识别的规范值。
|
||||
if normTier != rawTier {
|
||||
reqBody["service_tier"] = normTier
|
||||
bodyModified = true
|
||||
markPatchSet("service_tier", normTier)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
serializedByPatch := false
|
||||
@@ -2533,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
httpInvalidEncryptedContentRetryTried := false
|
||||
for {
|
||||
// Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
@@ -2735,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
body = sanitizedBody
|
||||
}
|
||||
|
||||
// Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
|
||||
// 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 +
|
||||
// OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。
|
||||
// 这样可以与 chat-completions / messages / native /responses 入口的
|
||||
// upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有
|
||||
// model 字段时退回 reqModel。
|
||||
policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||
if policyModel == "" {
|
||||
policyModel = reqModel
|
||||
}
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeOpenAIFastPolicyBlockedResponse(c, blocked)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
body = updatedBody
|
||||
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
||||
account.ID,
|
||||
@@ -2764,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
@@ -4008,8 +4096,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
||||
lastDataAt := time.Now()
|
||||
// Track downstream writes separately from upstream reads: pre-output failover
|
||||
// can buffer response.created / response.in_progress, so keepalive must be
|
||||
// based on downstream idle time.
|
||||
lastDownstreamWriteAt := time.Now()
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
||||
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
||||
@@ -4041,6 +4131,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
return
|
||||
}
|
||||
clientOutputStarted = true
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
@@ -4071,6 +4162,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||
} else if hadBufferedData {
|
||||
clientOutputStarted = true
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
@@ -4114,8 +4206,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if streamFailoverErr != nil {
|
||||
return
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||
|
||||
@@ -4170,6 +4260,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||
} else {
|
||||
clientOutputStarted = true
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4197,6 +4288,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||
} else {
|
||||
clientOutputStarted = true
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4283,7 +4375,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
|
||||
@@ -4294,6 +4386,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if err := flushBuffered(); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
|
||||
} else {
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4372,7 +4466,8 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
||||
return
|
||||
}
|
||||
eventType := gjson.GetBytes(data, "type").String()
|
||||
if eventType != "response.completed" && eventType != "response.done" {
|
||||
if eventType != "response.completed" && eventType != "response.done" &&
|
||||
eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4519,7 +4614,7 @@ func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
|
||||
switch eventType {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||
return eventType, []byte(data), true
|
||||
}
|
||||
}
|
||||
@@ -4834,7 +4929,18 @@ func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) {
|
||||
}
|
||||
|
||||
normalized := []byte(`{}`)
|
||||
for _, field := range []string{"model", "input", "instructions", "previous_response_id"} {
|
||||
// Keep the current Codex /compact schema while still dropping request-scoped
|
||||
// fields such as prompt_cache_key, store, and stream.
|
||||
for _, field := range []string{
|
||||
"model",
|
||||
"input",
|
||||
"instructions",
|
||||
"tools",
|
||||
"parallel_tool_calls",
|
||||
"reasoning",
|
||||
"text",
|
||||
"previous_response_id",
|
||||
} {
|
||||
value := gjson.GetBytes(body, field)
|
||||
if !value.Exists() {
|
||||
continue
|
||||
@@ -4935,13 +5041,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
|
||||
}
|
||||
|
||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
|
||||
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
@@ -5447,7 +5546,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
|
||||
}
|
||||
|
||||
// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为:
|
||||
// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false
|
||||
// 1) 删除 ChatGPT internal API 不支持的顶层 Responses 参数
|
||||
// 2) store=false 3) 非 compact 保持 stream=true;compact 强制 stream=false
|
||||
func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) {
|
||||
if len(body) == 0 {
|
||||
return body, false, nil
|
||||
@@ -5456,6 +5556,18 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, boo
|
||||
normalized := body
|
||||
changed := false
|
||||
|
||||
for _, field := range openAIChatGPTInternalUnsupportedFields {
|
||||
if value := gjson.GetBytes(normalized, field); !value.Exists() {
|
||||
continue
|
||||
}
|
||||
next, err := sjson.DeleteBytes(normalized, field)
|
||||
if err != nil {
|
||||
return body, false, fmt.Errorf("normalize passthrough body delete %s: %w", field, err)
|
||||
}
|
||||
normalized = next
|
||||
changed = true
|
||||
}
|
||||
|
||||
if compact {
|
||||
if store := gjson.GetBytes(normalized, "store"); store.Exists() {
|
||||
next, err := sjson.DeleteBytes(normalized, "store")
|
||||
@@ -5560,14 +5672,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
|
||||
if value == "fast" {
|
||||
value = "priority"
|
||||
}
|
||||
// 放过 OpenAI 官方文档定义的所有合法 tier 值:priority/flex/auto/default/scale。
|
||||
// 对 Codex 客户端零影响(Codex 只发 priority 或 flex,见 codex-rs/core/src/client.rs),
|
||||
// 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。
|
||||
// 真未知值仍返回 nil,由 normalizeResponsesBodyServiceTier 从 body 中删除。
|
||||
switch value {
|
||||
case "priority", "flex":
|
||||
case "priority", "flex", "auto", "default", "scale":
|
||||
return &value
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
|
||||
// policy (action=block). Mirrors BetaBlockedError on the Claude side.
|
||||
type OpenAIFastBlockedError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *OpenAIFastBlockedError) Error() string { return e.Message }
|
||||
|
||||
// evaluateOpenAIFastPolicy returns the action and error message that should be
|
||||
// applied for a request with the given account/model/service_tier. When the
|
||||
// policy service is unavailable or no rule matches, it returns
|
||||
// (BetaPolicyActionPass, "") so callers can short-circuit safely.
|
||||
//
|
||||
// Matching rules:
|
||||
// - Scope filters by account type (all / oauth / apikey / bedrock)
|
||||
// - ServiceTier must be empty (= any), "all", or equal the normalized tier
|
||||
// - ModelWhitelist narrows the rule to specific models; FallbackAction
|
||||
// handles the non-matching case (default: pass)
|
||||
//
|
||||
// 与 Claude BetaPolicy 的差异(保留首条匹配 short-circuit):
|
||||
// - BetaPolicy 处理的是 anthropic-beta header 中的 token 集合,不同
|
||||
// 规则可能针对不同 token,filter 需要累加成 set;block 则 first-match。
|
||||
// - OpenAI fast policy 操作的是单个字段 service_tier:filter 即删字段,
|
||||
// 没有可累加的对象。一次请求只携带一个 service_tier,规则的 tier
|
||||
// 维度天然互斥;同一 (scope, tier) 下若多条规则的 model whitelist
|
||||
// 发生重叠,admin 可通过规则顺序明确意图。因此采用 first-match 而
|
||||
// 非 BetaPolicy 那样的"block 覆盖 filter 覆盖 pass"语义。
|
||||
func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
|
||||
if s == nil || s.settingService == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
tier := strings.ToLower(strings.TrimSpace(serviceTier))
|
||||
if tier == "" {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
settings := openAIFastPolicySettingsFromContext(ctx)
|
||||
if settings == nil {
|
||||
fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
|
||||
if err != nil || fetched == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
settings = fetched
|
||||
}
|
||||
return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
|
||||
}
|
||||
|
||||
// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so
|
||||
// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting
|
||||
// the settingService on every frame. See WSSession entry and
|
||||
// openAIFastPolicySettingsFromContext for the caching glue.
|
||||
func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
|
||||
if settings == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
isOAuth := account != nil && account.IsOAuth()
|
||||
isBedrock := account != nil && account.IsBedrock()
|
||||
for _, rule := range settings.Rules {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
|
||||
if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
|
||||
continue
|
||||
}
|
||||
eff := BetaPolicyRule{
|
||||
Action: rule.Action,
|
||||
ErrorMessage: rule.ErrorMessage,
|
||||
ModelWhitelist: rule.ModelWhitelist,
|
||||
FallbackAction: rule.FallbackAction,
|
||||
FallbackErrorMessage: rule.FallbackErrorMessage,
|
||||
}
|
||||
return resolveRuleAction(eff, model)
|
||||
}
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
|
||||
// openAIFastPolicyCtxKey 是 context 中预取的 OpenAIFastPolicySettings 缓存
|
||||
// 键,仅用于 WebSocket 长会话内多帧复用同一份策略快照,避免每帧 DB 命中。
|
||||
//
|
||||
// Trade-off:策略变更不会影响当前 WS session(只影响新 session)。这是
|
||||
// 有意为之 —— 对长会话来说,"策略一致性"比"立刻生效"更重要,且 Claude
|
||||
// BetaPolicy 的 gin.Context 缓存也是同样取舍。需要 hot-reload 时管理员
|
||||
// 可以通过踢断 session 强制刷新。
|
||||
type openAIFastPolicyCtxKeyType struct{}
|
||||
|
||||
var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
|
||||
|
||||
// withOpenAIFastPolicyContext 将一份 settings 快照绑定到 context,供该 ctx
|
||||
// 衍生 goroutine 中的 evaluateOpenAIFastPolicy 复用。
|
||||
func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
|
||||
if ctx == nil || settings == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
|
||||
}
|
||||
|
||||
func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
if v, ok := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
|
||||
// body. When action=filter it removes the service_tier field; when
|
||||
// action=block it returns (body, *OpenAIFastBlockedError). On pass it
|
||||
// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
|
||||
// rewriting the body so the upstream receives a slug it recognizes.
|
||||
//
|
||||
// Rationale for normalize-on-pass: chat-completions / messages 入口在调用本
|
||||
// 函数之前已经通过 normalizeResponsesBodyServiceTier 把 service_tier 归一化
|
||||
// 到了上游可识别值;passthrough(OpenAI 自动透传) / native /responses 等
|
||||
// 入口没有这一前置步骤,pass 路径下若不在此处归一化,"fast" 就会被原样
|
||||
// 透传到 OpenAI 上游导致 400/拒绝。把归一化收敛到本函数,所有入口行为一致。
|
||||
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
|
||||
if len(body) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
rawTier := gjson.GetBytes(body, "service_tier").String()
|
||||
if rawTier == "" {
|
||||
return body, nil
|
||||
}
|
||||
normTier := normalizedOpenAIServiceTierValue(rawTier)
|
||||
if normTier == "" {
|
||||
return body, nil
|
||||
}
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
|
||||
}
|
||||
return body, &OpenAIFastBlockedError{Message: msg}
|
||||
case BetaPolicyActionFilter:
|
||||
trimmed, err := sjson.DeleteBytes(body, "service_tier")
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("strip service_tier from body: %w", err)
|
||||
}
|
||||
return trimmed, nil
|
||||
default:
|
||||
// pass:把别名(如 "fast")写回为规范值("priority")。
|
||||
if normTier == rawTier {
|
||||
return body, nil
|
||||
}
|
||||
updated, err := sjson.SetBytes(body, "service_tier", normTier)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("normalize service_tier on pass: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
}
|
||||
|
||||
// writeOpenAIFastPolicyBlockedResponse writes a 403 JSON response for a
|
||||
// request blocked by the OpenAI fast policy.
|
||||
func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
|
||||
if c == nil || err == nil {
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": err.Message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
|
||||
// against a single client→upstream WebSocket frame whose top-level
|
||||
// "type"=="response.create". It mirrors the HTTP-side
|
||||
// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
|
||||
// WS payload:
|
||||
//
|
||||
// - pass: returns frame unchanged (newBytes == frame, blocked == nil)
|
||||
// - filter: returns a copy with top-level service_tier removed
|
||||
// - block: returns (frame, *OpenAIFastBlockedError)
|
||||
//
|
||||
// Only frames whose "type" field strictly equals "response.create" are
|
||||
// inspected/mutated. Any other frame type — including the empty string —
|
||||
// passes through untouched. The OpenAI Realtime client-event spec requires
|
||||
// "type" to be set, so an empty type is treated as a malformed frame we do
|
||||
// not police; the upstream is the source of truth for rejecting it.
|
||||
//
|
||||
// service_tier lives at the top level of response.create — same as the
|
||||
// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
|
||||
// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
|
||||
// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
|
||||
// to inspect / strip the top-level field; there is no nested form in the
|
||||
// schema today.
|
||||
//
|
||||
// The caller is responsible for choosing the upstream model passed in —
|
||||
// this helper does not re-derive it.
|
||||
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
model string,
|
||||
frame []byte,
|
||||
) ([]byte, *OpenAIFastBlockedError, error) {
|
||||
if len(frame) == 0 {
|
||||
return frame, nil, nil
|
||||
}
|
||||
if !gjson.ValidBytes(frame) {
|
||||
return frame, nil, nil
|
||||
}
|
||||
frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
|
||||
// Strict match: only response.create is policy-checked. Empty / other
|
||||
// types pass through untouched so we never accidentally strip fields
|
||||
// from response.cancel, conversation.item.create, or any future
|
||||
// client-event the spec adds. The Realtime spec requires "type" on
|
||||
// every client event, so an empty type is malformed input — let the
|
||||
// upstream reject it rather than guessing at our layer.
|
||||
if frameType != "response.create" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
rawTier := gjson.GetBytes(frame, "service_tier").String()
|
||||
if rawTier == "" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
normTier := normalizedOpenAIServiceTierValue(rawTier)
|
||||
if normTier == "" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
|
||||
}
|
||||
return frame, &OpenAIFastBlockedError{Message: msg}, nil
|
||||
case BetaPolicyActionFilter:
|
||||
trimmed, err := sjson.DeleteBytes(frame, "service_tier")
|
||||
if err != nil {
|
||||
return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
|
||||
}
|
||||
return trimmed, nil, nil
|
||||
default:
|
||||
return frame, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a
|
||||
// server-emitted error event. Matches the loose "evt_<rand>" convention used
|
||||
// by upstream Realtime servers; the exact value is not load-bearing and is
|
||||
// only required for client-side log correlation. We reuse the existing
|
||||
// google/uuid dependency rather than pulling a new one.
|
||||
func newOpenAIFastPolicyWSEventID() string {
|
||||
id, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
// Extremely unlikely; fall back to a fixed prefix so the field is
|
||||
// still non-empty and the schema stays self-consistent.
|
||||
return "evt_openai_fast_policy"
|
||||
}
|
||||
// Strip dashes so it visually matches "evt_<hex>" rather than UUID v4
|
||||
// canonical form, mirroring what real Realtime traces look like.
|
||||
return "evt_" + strings.ReplaceAll(id.String(), "-", "")
|
||||
}
|
||||
|
||||
// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
|
||||
// style "error" event payload for a request blocked by the OpenAI fast
|
||||
// policy. The shape mirrors Realtime error events as observed in upstream
|
||||
// traces and per the spec's server "error" event:
|
||||
//
|
||||
// {
|
||||
// "event_id": "evt_<random>",
|
||||
// "type": "error",
|
||||
// "error": {
|
||||
// "type": "invalid_request_error",
|
||||
// "code": "policy_violation",
|
||||
// "message": "..."
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// event_id lets clients correlate the rejection in their logs; "code" gives
|
||||
// programmatic clients a stable identifier (HTTP-side equivalent is the
|
||||
// 403 permission_error JSON body).
|
||||
func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
eventID := newOpenAIFastPolicyWSEventID()
|
||||
payload, mErr := json.Marshal(map[string]any{
|
||||
"event_id": eventID,
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "invalid_request_error",
|
||||
"code": "policy_violation",
|
||||
"message": err.Message,
|
||||
},
|
||||
})
|
||||
if mErr != nil {
|
||||
// Fallback to a minimal hand-rolled payload; Marshal of the literal
|
||||
// shape above should never fail in practice.
|
||||
return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
|
||||
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
|
||||
return body, false, nil
|
||||
|
||||
@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := &OpenAIGatewayService{}
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
|
||||
|
||||
t.Run("stateless image body stays unstuck", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
|
||||
require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("prompt_cache_key is explicit", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
|
||||
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("header overrides body", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
c.Request.Header.Set("session_id", "header-session")
|
||||
|
||||
got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
|
||||
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -1117,6 +1152,46 @@ func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T)
|
||||
require.Empty(t, rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 1,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
|
||||
for i := 0; i < 6; i++ {
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
|
||||
}
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, rec.Body.String(), ":\n\n")
|
||||
require.Contains(t, rec.Body.String(), "response.completed")
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1336,6 +1411,41 @@ func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t
|
||||
require.Equal(t, 1, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughResponseIncompleteWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.incomplete\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 2, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.Equal(t, 1, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1657,6 +1767,24 @@ func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":"compact me"}],"instructions":"compact-test","tools":[{"type":"function","name":"shell"}],"parallel_tool_calls":true,"reasoning":{"effort":"high"},"text":{"verbosity":"low"},"previous_response_id":"resp_123","store":true,"stream":true,"prompt_cache_key":"cache_123"}`)
|
||||
|
||||
normalized, changed, err := normalizeOpenAICompactRequestBody(body)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, "gpt-5.5", gjson.GetBytes(normalized, "model").String())
|
||||
require.True(t, gjson.GetBytes(normalized, "tools").Exists())
|
||||
require.True(t, gjson.GetBytes(normalized, "parallel_tool_calls").Bool())
|
||||
require.Equal(t, "high", gjson.GetBytes(normalized, "reasoning.effort").String())
|
||||
require.Equal(t, "low", gjson.GetBytes(normalized, "text.verbosity").String())
|
||||
require.Equal(t, "resp_123", gjson.GetBytes(normalized, "previous_response_id").String())
|
||||
require.False(t, gjson.GetBytes(normalized, "store").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "stream").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "prompt_cache_key").Exists())
|
||||
}
|
||||
|
||||
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
var usage OpenAIUsage
|
||||
imageCount := parsed.N
|
||||
var firstTokenMs *int
|
||||
if parsed.Stream {
|
||||
if parsed.Stream && isEventStreamResponse(resp.Header) {
|
||||
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
||||
usage := OpenAIUsage{}
|
||||
imageCount := 0
|
||||
var firstTokenMs *int
|
||||
var fallbackBody bytes.Buffer
|
||||
fallbackBytes := int64(0)
|
||||
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
seenSSEData := false
|
||||
fallbackTooLarge := false
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
@@ -824,11 +829,24 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" {
|
||||
dataBytes := []byte(data)
|
||||
mergeOpenAIUsage(&usage, dataBytes)
|
||||
if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount {
|
||||
imageCount = count
|
||||
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
|
||||
if data != "" && data != "[DONE]" {
|
||||
seenSSEData = true
|
||||
fallbackBody.Reset()
|
||||
fallbackBytes = 0
|
||||
dataBytes := []byte(data)
|
||||
mergeOpenAIUsage(&usage, dataBytes)
|
||||
if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
|
||||
imageCount = count
|
||||
}
|
||||
}
|
||||
} else if !seenSSEData && !fallbackTooLarge {
|
||||
fallbackBytes += int64(len(line))
|
||||
if fallbackBytes <= fallbackLimit {
|
||||
_, _ = fallbackBody.Write(line)
|
||||
} else {
|
||||
fallbackTooLarge = true
|
||||
fallbackBody.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
||||
return OpenAIUsage{}, 0, firstTokenMs, err
|
||||
}
|
||||
}
|
||||
if !seenSSEData && fallbackBody.Len() > 0 {
|
||||
body := bytes.TrimSpace(fallbackBody.Bytes())
|
||||
if len(body) > 0 {
|
||||
mergeOpenAIUsage(&usage, body)
|
||||
if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
|
||||
imageCount = count
|
||||
}
|
||||
}
|
||||
}
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
|
||||
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
|
||||
if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 {
|
||||
return count
|
||||
}
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return 0
|
||||
}
|
||||
if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 {
|
||||
return count
|
||||
}
|
||||
if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 {
|
||||
return count
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String())
|
||||
if eventType == "" || !strings.HasSuffix(eventType, ".completed") {
|
||||
return 0
|
||||
}
|
||||
if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
||||
if dst == nil {
|
||||
return
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -258,6 +259,25 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T)
|
||||
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
||||
}
|
||||
|
||||
func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
|
||||
require.Equal(t,
|
||||
"https://image-upstream.example/v1/images/generations",
|
||||
buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint),
|
||||
)
|
||||
require.Equal(t,
|
||||
"https://image-upstream.example/v1/images/edits",
|
||||
buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint),
|
||||
)
|
||||
require.Equal(t,
|
||||
"https://image-upstream.example/v1/images/generations",
|
||||
buildOpenAIImagesURL("https://image-upstream.example", openAIImagesGenerationsEndpoint),
|
||||
)
|
||||
require.Equal(t,
|
||||
"https://image-upstream.example/v1/images/generations",
|
||||
buildOpenAIImagesURL("https://image-upstream.example/v1/images/generations", openAIImagesGenerationsEndpoint),
|
||||
)
|
||||
}
|
||||
|
||||
type openAIImageTestSSEEvent struct {
|
||||
Name string
|
||||
Data string
|
||||
@@ -371,6 +391,227 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req_img_apikey"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000007,"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, "gpt-image-2", result.Model)
|
||||
require.Equal(t, "gpt-image-2", result.UpstreamModel)
|
||||
|
||||
upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
|
||||
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type"))
|
||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req_img_stream_json"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 12, result.Usage.InputTokens)
|
||||
require.Equal(t, 21, result.Usage.OutputTokens)
|
||||
require.Equal(t, 9, result.Usage.ImageOutputTokens)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_json_mislabeled"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 10, result.Usage.InputTokens)
|
||||
require.Equal(t, 18, result.Usage.OutputTokens)
|
||||
require.Equal(t, 8, result.Usage.ImageOutputTokens)
|
||||
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
|
||||
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
|
||||
|
||||
require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.NoError(t, writer.WriteField("model", "gpt-image-2"))
|
||||
require.NoError(t, writer.WriteField("prompt", "replace background"))
|
||||
imagePart, err := writer.CreateFormFile("image", "source.png")
|
||||
require.NoError(t, err)
|
||||
_, err = imagePart.Write([]byte("png-image-content"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req_img_edit_apikey"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"data":[{"b64_json":"ZWRpdGVk","revised_prompt":"replace background"}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1/",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
|
||||
upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://image-upstream.example/v1/images/edits", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
|
||||
require.Contains(t, upstream.lastReq.Header.Get("Content-Type"), "multipart/form-data")
|
||||
require.Contains(t, string(upstream.lastBody), `name="model"`)
|
||||
require.Contains(t, string(upstream.lastBody), "gpt-image-2")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
|
||||
|
||||
@@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
|
||||
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
cancel()
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
@@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
|
||||
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
cancel()
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"prompt_cache_retention":"24h","safety_identifier":"sid","stream_options":{"include_usage":true}}`)
|
||||
|
||||
normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
for _, field := range openAIChatGPTInternalUnsupportedFields {
|
||||
require.False(t, gjson.GetBytes(normalized, field).Exists(), "%s should be stripped", field)
|
||||
}
|
||||
require.True(t, gjson.GetBytes(normalized, "stream").Bool())
|
||||
require.False(t, gjson.GetBytes(normalized, "store").Bool())
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"stream":true,"store":true}`)
|
||||
|
||||
normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.False(t, gjson.GetBytes(normalized, "user").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "metadata").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "stream").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "store").Exists())
|
||||
}
|
||||
@@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string {
|
||||
|
||||
// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
|
||||
type OpenAIWSIngressHooks struct {
|
||||
BeforeTurn func(turn int) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
// InitialRequestModel 是首帧渠道映射前的请求模型,只用于 usage metadata
|
||||
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
|
||||
InitialRequestModel string
|
||||
BeforeTurn func(turn int) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSLogValue(value string) string {
|
||||
@@ -1366,16 +1369,27 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string
|
||||
func shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
storeDisabled bool,
|
||||
turn int,
|
||||
hasFunctionCallOutput bool,
|
||||
signals ToolContinuationSignals,
|
||||
currentPreviousResponseID string,
|
||||
expectedPreviousResponseID string,
|
||||
) bool {
|
||||
if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
|
||||
if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(currentPreviousResponseID) != "" {
|
||||
return false
|
||||
}
|
||||
if signals.HasFunctionCallOutputMissingCallID {
|
||||
return false
|
||||
}
|
||||
// If the client already sent the actual tool-call context, treat this as
|
||||
// a full replay / self-contained continuation payload rather than
|
||||
// downgrading it into an inferred delta continuation. item_reference alone
|
||||
// is not enough on the store=false WS path: it still needs a valid prior
|
||||
// response anchor so upstream can resolve the referenced function_call.
|
||||
if signals.HasToolCallContext {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(expectedPreviousResponseID) != ""
|
||||
}
|
||||
|
||||
@@ -2366,6 +2380,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
return errors.New("token is empty")
|
||||
}
|
||||
|
||||
// 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session
|
||||
// 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧
|
||||
// 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。
|
||||
if s.settingService != nil {
|
||||
if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
|
||||
ctx = withOpenAIFastPolicyContext(ctx, settings)
|
||||
}
|
||||
}
|
||||
|
||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
||||
ingressMode := OpenAIWSIngressModeCtxPool
|
||||
@@ -2524,6 +2547,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
normalized = next
|
||||
}
|
||||
|
||||
// Apply OpenAI Fast Policy on the response.create frame using the same
|
||||
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
|
||||
// single integration point for all WS ingress turns (first + follow-up
|
||||
// frames flow through here).
|
||||
//
|
||||
// Model fallback: parseClientPayload above rejects any frame whose
|
||||
// "model" field is missing (line ~2493-2500), so by the time we
|
||||
// reach this point upstreamModel is always derived from a non-empty
|
||||
// per-frame model. The capturedSessionModel fallback used in the
|
||||
// passthrough adapter is therefore not needed in this path.
|
||||
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
|
||||
if policyErr != nil {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
|
||||
}
|
||||
if blocked != nil {
|
||||
// Send a Realtime-style error event to the client first, then
|
||||
// signal the handler to close the connection with PolicyViolation.
|
||||
// We intentionally do NOT forward this frame upstream.
|
||||
//
|
||||
// coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
|
||||
// the underlying bufio writer before returning (write.go:42 →
|
||||
// 307-311), and the subsequent close handshake re-acquires the
|
||||
// same writeFrameMu, so the error event is guaranteed to reach
|
||||
// the kernel send buffer before any close frame is queued.
|
||||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||||
if eventBytes != nil {
|
||||
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||||
cancel()
|
||||
}
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
blocked.Message,
|
||||
blocked,
|
||||
)
|
||||
}
|
||||
normalized = policyApplied
|
||||
|
||||
return openAIWSClientPayload{
|
||||
payloadRaw: normalized,
|
||||
rawForHash: trimmed,
|
||||
@@ -3132,13 +3193,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
skipBeforeTurn = false
|
||||
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
|
||||
expectedPrev := strings.TrimSpace(lastTurnResponseID)
|
||||
hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
|
||||
toolSignals := ToolContinuationSignals{
|
||||
HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
|
||||
}
|
||||
if toolSignals.HasFunctionCallOutput {
|
||||
var currentReqBody map[string]any
|
||||
if err := json.Unmarshal(currentPayload, ¤tReqBody); err == nil {
|
||||
toolSignals = AnalyzeToolContinuationSignals(currentReqBody)
|
||||
}
|
||||
}
|
||||
hasFunctionCallOutput := toolSignals.HasFunctionCallOutput
|
||||
// store=false + function_call_output 场景必须有续链锚点。
|
||||
// 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
|
||||
if shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
storeDisabled,
|
||||
turn,
|
||||
hasFunctionCallOutput,
|
||||
toolSignals,
|
||||
currentPreviousResponseID,
|
||||
expectedPrev,
|
||||
) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user