Mirror of https://gitee.com/wanwujie/sub2api
Synced 2026-05-05 05:30:44 +08:00

Compare commits (52 commits)
| SHA1 |
|---|
| 8bf2a7b88a |
| 40feb86ba4 |
| f972a2faf2 |
| 55a7fa1e07 |
| 5e54d492be |
| 8d6d31545f |
| 17ced6b73a |
| 7f8f3fe0dd |
| 46f06b2498 |
| 7ce5b83215 |
| 27cad10d30 |
| ff6fa0203d |
| f7c13af11f |
| 28dc34b6a3 |
| 4d676dddd1 |
| 93d91e20b9 |
| 63ef23108c |
| d78478e866 |
| bf43fb4e38 |
| a16c66500f |
| 4b6954f9f0 |
| da4b078df2 |
| 7452fad820 |
| 4c474616b9 |
| 6327573534 |
| 04b2866f65 |
| b0a2252ed1 |
| 30f55a1f72 |
| 3d4ca5e8d1 |
| 0537a490f0 |
| ca5d029e7c |
| 1eca03432a |
| 53b24bc2d8 |
| a161f9d045 |
| c5a1a82223 |
| 2ab6b34fd1 |
| 764afbe37a |
| 25c7b0d9f4 |
| f422ac6dcc |
| 54de4e008c |
| 65c27d2c69 |
| 53f919f8f0 |
| c92b88e34a |
| ed0c85a17e |
| 9fe02bba7e |
| 615557ec20 |
| 3f05ef2ae3 |
| 3022090365 |
| 798fd673e9 |
| c056db740d |
| 6d11f9ed77 |
| 489a4d934e |
```diff
@@ -101,6 +101,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
 <td>Thanks to Bestproxy for sponsoring this project! <a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.</td>
 </tr>
 
+<tr>
+<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
+<td>Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers: no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line.
+Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
+Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 40% off, and referring friends earns both parties rewards, with referral bonuses up to $150.</td>
+</tr>
+
 </table>
 
 ## Ecosystem
```
```diff
@@ -100,6 +100,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
 <td>感谢 Bestproxy 赞助了本项目!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。</td>
 </tr>
 
+<tr>
+<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
+<td>感谢 PatewayAI 赞助了本项目!PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型,100% 官方源直供,不掺假不注水,欢迎检验。计费透明,Token 级账单可逐笔核验。
+同时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。
+现在通过 <a href="https://pateway.ai/?ch=1tsfr51">此链接</a> 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。</td>
+</tr>
+
 </table>
 
 ## 生态项目
```
```diff
@@ -100,6 +100,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
 <td>Bestproxy のご支援に感謝します!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。</td>
 </tr>
 
+<tr>
+<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
+<td>PatewayAI のご支援に感謝します!PatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給。偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。
+エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。
+<a href="https://pateway.ai/?ch=1tsfr51">こちらのリンク</a>から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与。紹介報酬は最大 $150。</td>
+</tr>
+
 </table>
 
 ## エコシステム
```
BIN assets/partners/logos/pateway.png (new file, 8.0 KiB; binary file not shown)
```diff
@@ -1 +1 @@
-0.1.118
+0.1.119
```
```diff
@@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	userGroupRateRepository := repository.NewUserGroupRateRepository(db)
 	billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
 	apiKeyCache := repository.NewAPIKeyCache(redisClient)
-	apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
+	apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
 	apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
 	promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
 	subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
```
```diff
@@ -145,13 +145,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
 	oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
 	geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
+	claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
 	gatewayCache := repository.NewGatewayCache(redisClient)
 	schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
 	schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
 	antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
 	internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
 	antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
-	accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
+	accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
 	crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
 	accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
 	adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
```
```diff
@@ -178,7 +179,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	billingService := service.NewBillingService(configConfig, pricingService)
 	identityService := service.NewIdentityService(identityCache)
 	deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
-	claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
 	digestSessionStore := service.NewDigestSessionStore()
 	channelRepository := repository.NewChannelRepository(db)
 	channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
```
```diff
@@ -186,7 +186,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
 	gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
 	openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
-	openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
+	openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
 	geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
 	opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
 	opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
```
```diff
@@ -26,11 +26,12 @@ const (
 
 // Account type constants
 const (
 	AccountTypeOAuth      = "oauth"       // OAuth类型账号(full scope: profile + inference)
 	AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
 	AccountTypeAPIKey     = "apikey"      // API Key类型账号
 	AccountTypeUpstream   = "upstream"    // 上游透传类型账号(通过 Base URL + API Key 连接上游)
 	AccountTypeBedrock    = "bedrock"     // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
+	AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
 )
 
 // Redeem type constants
```
```diff
@@ -98,7 +98,7 @@ type CreateAccountRequest struct {
 	Name        string         `json:"name" binding:"required"`
 	Notes       *string        `json:"notes"`
 	Platform    string         `json:"platform" binding:"required"`
-	Type        string         `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
+	Type        string         `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
 	Credentials map[string]any `json:"credentials" binding:"required"`
 	Extra       map[string]any `json:"extra"`
 	ProxyID     *int64         `json:"proxy_id"`
```
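The `oneof` binding tag above now admits the new `service_account` type. As a rough, hypothetical sketch (independent of gin/validator), the tag enforces simple set membership over the fixed list of account types:

```go
package main

import "fmt"

// Hypothetical stand-in for what the binding tag
// `oneof=oauth setup-token apikey upstream bedrock service_account`
// enforces: the account type must belong to a fixed set, which this
// change extends with "service_account" (used for Vertex AI).
var allowedAccountTypes = map[string]bool{
	"oauth":           true,
	"setup-token":     true,
	"apikey":          true,
	"upstream":        true,
	"bedrock":         true,
	"service_account": true,
}

func validAccountType(t string) bool { return allowedAccountTypes[t] }

func main() {
	fmt.Println(validAccountType("service_account")) // true
	fmt.Println(validAccountType("vertex"))          // false
}
```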
```diff
@@ -117,7 +117,7 @@ type CreateAccountRequest struct {
 type UpdateAccountRequest struct {
 	Name        string         `json:"name"`
 	Notes       *string        `json:"notes"`
-	Type        string         `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
+	Type        string         `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
 	Credentials map[string]any `json:"credentials"`
 	Extra       map[string]any `json:"extra"`
 	ProxyID     *int64         `json:"proxy_id"`
```
```diff
@@ -134,19 +134,29 @@ type UpdateAccountRequest struct {
 
 // BulkUpdateAccountsRequest represents the payload for bulk editing accounts
 type BulkUpdateAccountsRequest struct {
-	AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
+	AccountIDs []int64                   `json:"account_ids"`
+	Filters    *BulkUpdateAccountFilters `json:"filters"`
 	Name           string   `json:"name"`
 	ProxyID        *int64   `json:"proxy_id"`
 	Concurrency    *int     `json:"concurrency"`
 	Priority       *int     `json:"priority"`
 	RateMultiplier *float64 `json:"rate_multiplier"`
 	LoadFactor     *int     `json:"load_factor"`
 	Status         string   `json:"status" binding:"omitempty,oneof=active inactive error"`
 	Schedulable    *bool    `json:"schedulable"`
 	GroupIDs       *[]int64 `json:"group_ids"`
 	Credentials map[string]any `json:"credentials"`
 	Extra       map[string]any `json:"extra"`
 	ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
 }
+
+type BulkUpdateAccountFilters struct {
+	Platform    string `json:"platform"`
+	Type        string `json:"type"`
+	Status      string `json:"status"`
+	Group       string `json:"group"`
+	Search      string `json:"search"`
+	PrivacyMode string `json:"privacy_mode"`
+}
 
 // CheckMixedChannelRequest represents check mixed channel risk request
```
```diff
@@ -1369,6 +1379,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
 		response.BadRequest(c, "rate_multiplier must be >= 0")
 		return
 	}
+	if len(req.AccountIDs) == 0 && req.Filters == nil {
+		response.BadRequest(c, "account_ids or filters is required")
+		return
+	}
 	// base_rpm 输入校验:负值归零,超过 10000 截断
 	sanitizeExtraBaseRPM(req.Extra)
 
```
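With this change, `account_ids` is no longer a required field: a bulk update may instead be targeted by `filters`, and only a request that supplies neither is rejected. A minimal, self-contained sketch of that either-or guard (`bulkUpdateReq` and `bulkUpdateFilters` here are hypothetical trimmed stand-ins for the real DTOs):

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// Trimmed stand-in: only the two fields involved in the either-or rule.
type bulkUpdateFilters struct {
	Platform string `json:"platform"`
}

type bulkUpdateReq struct {
	AccountIDs []int64            `json:"account_ids"`
	Filters    *bulkUpdateFilters `json:"filters"`
}

// validate mirrors the added handler guard: a bulk update must name
// explicit account IDs or supply a filter; an empty target is rejected.
func validate(raw []byte) error {
	var req bulkUpdateReq
	if err := json.Unmarshal(raw, &req); err != nil {
		return err
	}
	if len(req.AccountIDs) == 0 && req.Filters == nil {
		return errors.New("account_ids or filters is required")
	}
	return nil
}

func main() {
	fmt.Println(validate([]byte(`{}`)))                                // account_ids or filters is required
	fmt.Println(validate([]byte(`{"account_ids":[1,2]}`)))             // <nil>
	fmt.Println(validate([]byte(`{"filters":{"platform":"openai"}}`))) // <nil>
}
```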
```diff
@@ -1394,6 +1408,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
 
 	result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
 		AccountIDs:  req.AccountIDs,
+		Filters:     toServiceBulkUpdateAccountFilters(req.Filters),
 		Name:        req.Name,
 		ProxyID:     req.ProxyID,
 		Concurrency: req.Concurrency,
```
```diff
@@ -1429,6 +1444,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
 	response.Success(c, result)
 }
 
+func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters {
+	if filters == nil {
+		return nil
+	}
+	return &service.BulkUpdateAccountFilters{
+		Platform:    filters.Platform,
+		Type:        filters.Type,
+		Status:      filters.Status,
+		Group:       filters.Group,
+		Search:      filters.Search,
+		PrivacyMode: filters.PrivacyMode,
+	}
+}
+
 // ========== OAuth Handlers ==========
 
 // GenerateAuthURLRequest represents the request for generating auth URL
```
```diff
@@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
 	require.Equal(t, float64(2), data["success"])
 	require.Equal(t, float64(0), data["failed"])
 }
+
+func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) {
+	adminSvc := newStubAdminService()
+	router := setupAccountMixedChannelRouter(adminSvc)
+
+	body, _ := json.Marshal(map[string]any{
+		"filters": map[string]any{
+			"platform":     "openai",
+			"type":         "oauth",
+			"status":       "active",
+			"group":        "12",
+			"privacy_mode": "blocked",
+			"search":       "bulk-target",
+		},
+		"schedulable": true,
+	})
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	router.ServeHTTP(rec, req)
+
+	require.Equal(t, http.StatusOK, rec.Code)
+	var resp map[string]any
+	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+	require.Equal(t, float64(0), resp["code"])
+}
```
```diff
@@ -8,6 +8,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/Wei-Shaw/sub2api/internal/handler/dto"
 	"github.com/Wei-Shaw/sub2api/internal/service"
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/require"
```
```diff
@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
 	require.True(t, isAddrInTrustedProxies(addr, prefixes))
 	require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
 }
+
+// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin
+// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为
+// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。
+func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
+	t.Run("nil input returns nil", func(t *testing.T) {
+		require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
+	})
+
+	t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
+		in := &dto.OpenAIFastPolicySettings{
+			Rules: []dto.OpenAIFastPolicyRule{{
+				ServiceTier: "",
+				Action:      "filter",
+				Scope:       "all",
+			}},
+		}
+		out := openaiFastPolicySettingsFromDTO(in)
+		require.NotNil(t, out)
+		require.Len(t, out.Rules, 1)
+		require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
+		require.Equal(t, "all", out.Rules[0].ServiceTier)
+	})
+
+	t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
+		in := &dto.OpenAIFastPolicySettings{
+			Rules: []dto.OpenAIFastPolicyRule{{
+				ServiceTier: " ",
+				Action:      "pass",
+				Scope:       "all",
+			}},
+		}
+		out := openaiFastPolicySettingsFromDTO(in)
+		require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
+	})
+
+	t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
+		in := &dto.OpenAIFastPolicySettings{
+			Rules: []dto.OpenAIFastPolicyRule{{
+				ServiceTier: "PRIORITY",
+				Action:      "filter",
+				Scope:       "all",
+			}},
+		}
+		out := openaiFastPolicySettingsFromDTO(in)
+		require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
+	})
+
+	t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
+		in := &dto.OpenAIFastPolicySettings{
+			Rules: []dto.OpenAIFastPolicyRule{
+				{ServiceTier: "priority", Action: "filter", Scope: "all"},
+				{ServiceTier: "flex", Action: "block", Scope: "oauth"},
+				{ServiceTier: "all", Action: "pass", Scope: "apikey"},
+			},
+		}
+		out := openaiFastPolicySettingsFromDTO(in)
+		require.Len(t, out.Rules, 3)
+		require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
+		require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
+		require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
+	})
+}
```
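The tests above pin down one normalization rule: an empty or whitespace-only `service_tier` collapses to the catch-all `"all"` (`service.OpenAIFastTierAny` in the real code), and all other values are lowercased. A hedged, stand-alone sketch of that rule (`normalizeServiceTier` is a hypothetical helper, not the repository's function name):

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeServiceTier sketches the behavior the tests assert: trim
// whitespace, lowercase, and map the empty string to "all" so that ""
// and "all" never persist as two spellings of the same meaning.
func normalizeServiceTier(tier string) string {
	t := strings.ToLower(strings.TrimSpace(tier))
	if t == "" {
		return "all"
	}
	return t
}

func main() {
	fmt.Println(normalizeServiceTier(""))         // all
	fmt.Println(normalizeServiceTier("   "))      // all
	fmt.Println(normalizeServiceTier("PRIORITY")) // priority
	fmt.Println(normalizeServiceTier("flex"))     // flex
}
```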
```diff
@@ -565,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
 	return nil, service.ErrAPIKeyNotFound
 }
 
+func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) {
+	for i := range s.apiKeys {
+		if s.apiKeys[i].ID == keyID {
+			s.apiKeys[i].Usage5h = 0
+			s.apiKeys[i].Usage1d = 0
+			s.apiKeys[i].Usage7d = 0
+			s.apiKeys[i].Window5hStart = nil
+			s.apiKeys[i].Window1dStart = nil
+			s.apiKeys[i].Window7dStart = nil
+			k := s.apiKeys[i]
+			return &k, nil
+		}
+	}
+	return nil, service.ErrAPIKeyNotFound
+}
+
 func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
 	return nil
 }
```
```diff
@@ -22,12 +22,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle
 	}
 }
 
-// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
+// AdminUpdateAPIKeyGroupRequest represents the request to update an API key.
 type AdminUpdateAPIKeyGroupRequest struct {
 	GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
+	ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true=重置 5h/1d/7d 限速用量
 }
 
-// UpdateGroup handles updating an API key's group binding
+// UpdateGroup handles updating an API key's admin-managed fields.
 // PUT /api/v1/admin/api-keys/:id
 func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
 	keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
```
```diff
@@ -42,11 +43,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
 		return
 	}
 
+	var resetKey *service.APIKey
+	if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage {
+		resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID)
+		if err != nil {
+			response.ErrorFrom(c, err)
+			return
+		}
+	}
+
 	result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
 	if err != nil {
 		response.ErrorFrom(c, err)
 		return
 	}
+	if resetKey != nil && req.GroupID == nil {
+		result.APIKey = resetKey
+	}
 
 	resp := struct {
 		APIKey *dto.APIKey `json:"api_key"`
```
@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"testing"
+	"time"
 
 	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
 	"github.com/Wei-Shaw/sub2api/internal/service"

@@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
 	require.Nil(t, resp.Data.APIKey.GroupID)
 }
 
+func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) {
+	svc := newStubAdminService()
+	now := time.Now()
+	svc.apiKeys[0].Usage5h = 1.2
+	svc.apiKeys[0].Usage1d = 3.4
+	svc.apiKeys[0].Usage7d = 5.6
+	svc.apiKeys[0].Window5hStart = &now
+	svc.apiKeys[0].Window1dStart = &now
+	svc.apiKeys[0].Window7dStart = &now
+	router := setupAPIKeyHandler(svc)
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`))
+	req.Header.Set("Content-Type", "application/json")
+	router.ServeHTTP(rec, req)
+
+	require.Equal(t, http.StatusOK, rec.Code)
+
+	var resp struct {
+		Data struct {
+			APIKey struct {
+				Usage5h       float64    `json:"usage_5h"`
+				Usage1d       float64    `json:"usage_1d"`
+				Usage7d       float64    `json:"usage_7d"`
+				Window5hStart *time.Time `json:"window_5h_start"`
+				Window1dStart *time.Time `json:"window_1d_start"`
+				Window7dStart *time.Time `json:"window_7d_start"`
+			} `json:"api_key"`
+		} `json:"data"`
+	}
+	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+	require.Zero(t, resp.Data.APIKey.Usage5h)
+	require.Zero(t, resp.Data.APIKey.Usage1d)
+	require.Zero(t, resp.Data.APIKey.Usage7d)
+	require.Nil(t, resp.Data.APIKey.Window5hStart)
+	require.Nil(t, resp.Data.APIKey.Window1dStart)
+	require.Nil(t, resp.Data.APIKey.Window7dStart)
+}
+
 func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
 	svc := &failingUpdateGroupService{
 		stubAdminService: newStubAdminService(),
@@ -248,9 +248,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
 		AffiliateEnabled: settings.AffiliateEnabled,
 	}
 
+	// OpenAI fast policy (stored under a dedicated setting key)
+	if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
+		slog.Error("openai_fast_policy_settings_get_failed", "error", err)
+	} else if fastPolicy != nil {
+		payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
+	}
+
 	response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
 }
 
+// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
+func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
+	if s == nil {
+		return nil
+	}
+	rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
+	for i, r := range s.Rules {
+		rules[i] = dto.OpenAIFastPolicyRule(r)
+	}
+	return &dto.OpenAIFastPolicySettings{Rules: rules}
+}
+
+// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
+//
+// Normalize ServiceTier: before the DTO reaches the service layer, coerce the
+// empty string to service.OpenAIFastTierAny ("all"), so that an empty string
+// and "all" cannot both mean "match any tier" and make the stored database
+// value ambiguous. Other non-empty values pass through unchanged;
+// service.SetOpenAIFastPolicySettings is responsible for validating them.
+func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
+	if s == nil {
+		return nil
+	}
+	rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
+	for i, r := range s.Rules {
+		rules[i] = service.OpenAIFastPolicyRule(r)
+		tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
+		if tier == "" {
+			tier = service.OpenAIFastTierAny
+		}
+		rules[i].ServiceTier = tier
+	}
+	return &service.OpenAIFastPolicySettings{Rules: rules}
+}
+
 // UpdateSettingsRequest is the update-settings request
 type UpdateSettingsRequest struct {
 	// Registration settings

@@ -452,6 +494,9 @@ type UpdateSettingsRequest struct {
 
 	// Affiliate (referral rebate) feature switch
 	AffiliateEnabled *bool `json:"affiliate_enabled"`
+
+	// OpenAI fast/flex policy (optional, only updated when provided)
+	OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
 }
 
 // UpdateSettings updates the system settings

@@ -1350,6 +1395,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
 		return
 	}
 
+	// Update OpenAI fast policy (stored under dedicated key, only when provided).
+	if req.OpenAIFastPolicySettings != nil {
+		if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
+			response.BadRequest(c, err.Error())
+			return
+		}
+	}
+
 	// Update payment configuration (integrated into system settings).
 	// Skip if no payment fields were provided (prevents accidental wipe).
 	if h.paymentConfigService != nil && hasPaymentFields(req) {

@@ -1555,6 +1608,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
 
 		AffiliateEnabled: updatedSettings.AffiliateEnabled,
 	}
+	if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
+		slog.Error("openai_fast_policy_settings_get_failed", "error", err)
+	} else if fastPolicy != nil {
+		payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
+	}
 	response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
 }
 
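As a standalone illustration of the ServiceTier normalization described above (empty string coerced to the "match any tier" sentinel before validation), here is a minimal sketch; `tierAny` below stands in for `service.OpenAIFastTierAny` and is illustrative, not the repository's code:

```go
package main

import (
	"fmt"
	"strings"
)

// tierAny mirrors the idea of service.OpenAIFastTierAny ("all" matches any
// service tier); the constant name here is an assumption for the sketch.
const tierAny = "all"

// normalizeTier trims and lowercases a tier, coercing the empty string to
// tierAny so "" and "all" cannot coexist as two spellings of "any tier".
func normalizeTier(tier string) string {
	t := strings.ToLower(strings.TrimSpace(tier))
	if t == "" {
		return tierAny
	}
	return t
}

func main() {
	fmt.Println(normalizeTier(""))        // all
	fmt.Println(normalizeTier("  Flex ")) // flex
}
```

Non-empty values are only canonicalized in form, never rewritten; validating them against the allowed set stays the service layer's job.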
@@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
 }
 
 func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
-	panic("unexpected GetValue call")
+	if s.values != nil {
+		if value, ok := s.values[key]; ok {
+			return value, nil
+		}
+	}
+	return "", nil
 }
 
 func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
@@ -198,6 +198,9 @@ type SystemSettings struct {
 
 	// Affiliate (referral rebate) feature switch
 	AffiliateEnabled bool `json:"affiliate_enabled"`
+
+	// OpenAI fast/flex policy
+	OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
 }
 
 type DefaultSubscriptionSetting struct {

@@ -294,6 +297,22 @@ type BetaPolicySettings struct {
 	Rules []BetaPolicyRule `json:"rules"`
 }
 
+// OpenAIFastPolicyRule is the OpenAI fast/flex policy rule DTO
+type OpenAIFastPolicyRule struct {
+	ServiceTier          string   `json:"service_tier"`
+	Action               string   `json:"action"`
+	Scope                string   `json:"scope"`
+	ErrorMessage         string   `json:"error_message,omitempty"`
+	ModelWhitelist       []string `json:"model_whitelist,omitempty"`
+	FallbackAction       string   `json:"fallback_action,omitempty"`
+	FallbackErrorMessage string   `json:"fallback_error_message,omitempty"`
+}
+
+// OpenAIFastPolicySettings is the OpenAI fast policy settings DTO
+type OpenAIFastPolicySettings struct {
+	Rules []OpenAIFastPolicyRule `json:"rules"`
+}
+
 // ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
 // Returns empty slice on empty/invalid input.
 func ParseCustomMenuItems(raw string) []CustomMenuItem {
@@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.
 func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
 	return true, nil
 }
+func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error {
+	return nil
+}
 func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
 	return nil, nil
 }
@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
 		return
 	}
 
-	sessionHash := ""
-	if parsed.Multipart {
-		sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
-	} else {
-		sessionHash = h.gatewayService.GenerateSessionHash(c, body)
-	}
+	sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
 
 	maxAccountSwitches := h.maxAccountSwitches
 	switchCount := 0
@@ -258,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
 	assert.Equal(t, "tool_use", anth.Content[1].Type)
 	assert.Equal(t, "call_1", anth.Content[1].ID)
 	assert.Equal(t, "get_weather", anth.Content[1].Name)
+	assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
+}
+
+func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
+	resp := &ResponsesResponse{
+		ID:     "resp_read",
+		Model:  "gpt-5.5",
+		Status: "completed",
+		Output: []ResponsesOutput{
+			{
+				Type:      "function_call",
+				CallID:    "call_read",
+				Name:      "Read",
+				Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+			},
+		},
+	}
+
+	anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
+	require.Len(t, anth.Content, 1)
+	assert.Equal(t, "tool_use", anth.Content[0].Type)
+	assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input))
+}
+
+func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
+	resp := &ResponsesResponse{
+		ID:     "resp_other",
+		Model:  "gpt-5.5",
+		Status: "completed",
+		Output: []ResponsesOutput{
+			{
+				Type:      "function_call",
+				CallID:    "call_other",
+				Name:      "Search",
+				Arguments: `{"query":""}`,
+			},
+		},
+	}
+
+	anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
+	require.Len(t, anth.Content, 1)
+	assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input))
 }
 
 func TestResponsesToAnthropic_Reasoning(t *testing.T) {

@@ -472,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) {
 	assert.Equal(t, "tool_use", events[0].Delta.StopReason)
 }
 
+func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
+	state := NewResponsesEventToAnthropicState()
+
+	ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+		Type:     "response.created",
+		Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
+	}, state)
+
+	events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+		Type:        "response.output_item.added",
+		OutputIndex: 0,
+		Item:        &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
+	}, state)
+	require.Len(t, events, 1)
+	assert.Equal(t, "content_block_start", events[0].Type)
+
+	events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+		Type:        "response.function_call_arguments.delta",
+		OutputIndex: 0,
+		Delta:       `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+	}, state)
+	assert.Len(t, events, 0)
+
+	events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+		Type:        "response.function_call_arguments.done",
+		OutputIndex: 0,
+		Arguments:   `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+	}, state)
+	require.Len(t, events, 2)
+	assert.Equal(t, "content_block_delta", events[0].Type)
+	assert.Equal(t, "input_json_delta", events[0].Delta.Type)
+	assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON)
+	assert.Equal(t, "content_block_stop", events[1].Type)
+}
+
 func TestStreamingReasoning(t *testing.T) {
 	state := NewResponsesEventToAnthropicState()
 

@@ -914,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
 	var tc map[string]any
 	require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
 	assert.Equal(t, "function", tc["type"])
-	fn, ok := tc["function"].(map[string]any)
-	require.True(t, ok)
-	assert.Equal(t, "get_weather", fn["name"])
+	assert.Equal(t, "get_weather", tc["name"])
+	assert.NotContains(t, tc, "function")
+}
+
+func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) {
+	req := &ResponsesRequest{
+		Model:      "gpt-5.2",
+		Input:      json.RawMessage(`[{"role":"user","content":"Hello"}]`),
+		ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`),
+	}
+
+	resp, err := ResponsesToAnthropicRequest(req)
+	require.NoError(t, err)
+
+	var tc map[string]string
+	require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
+	assert.Equal(t, "tool", tc["type"])
+	assert.Equal(t, "get_weather", tc["name"])
+}
+
+func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) {
+	req := &ResponsesRequest{
+		Model:      "gpt-5.2",
+		Input:      json.RawMessage(`[{"role":"user","content":"Hello"}]`),
+		ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`),
+	}
+
+	resp, err := ResponsesToAnthropicRequest(req)
+	require.NoError(t, err)
+
+	var tc map[string]string
+	require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
+	assert.Equal(t, "tool", tc["type"])
+	assert.Equal(t, "get_weather", tc["name"])
 }
 
 // ---------------------------------------------------------------------------
@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
 // {"type":"auto"} → "auto"
 // {"type":"any"} → "required"
 // {"type":"none"} → "none"
-// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
+// {"type":"tool","name":"X"} → {"type":"function","name":"X"}
 func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
 	var tc struct {
 		Type string `json:"type"`

@@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
 		return json.Marshal("none")
 	case "tool":
 		return json.Marshal(map[string]any{
 			"type": "function",
-			"function": map[string]string{"name": tc.Name},
+			"name": tc.Name,
 		})
 	default:
 		// Pass through unknown types as-is
@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
 	var tc map[string]any
 	require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
 	assert.Equal(t, "function", tc["type"])
+	assert.Equal(t, "get_weather", tc["name"])
+	assert.NotContains(t, tc, "function")
 }
 
 func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
 //
 // "auto" → "auto"
 // "none" → "none"
-// {"name":"X"} → {"type":"function","function":{"name":"X"}}
+// {"name":"X"} → {"type":"function","name":"X"}
 func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
 	// Try string first ("auto", "none", etc.) — pass through as-is.
 	var s string

@@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
 		return nil, err
 	}
 	return json.Marshal(map[string]any{
 		"type": "function",
-		"function": map[string]string{"name": obj.Name},
+		"name": obj.Name,
 	})
 }
@@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
 				Type: "tool_use",
 				ID:   fromResponsesCallID(item.CallID),
 				Name: item.Name,
-				Input: json.RawMessage(item.Arguments),
+				Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
 			})
 		case "web_search_call":
 			toolUseID := "srvtoolu_" + item.ID

@@ -129,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
 	}
 }
 
+func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
+	if name != "Read" || raw == "" {
+		return json.RawMessage(raw)
+	}
+
+	var input map[string]json.RawMessage
+	if err := json.Unmarshal([]byte(raw), &input); err != nil {
+		return json.RawMessage(raw)
+	}
+
+	if pages, ok := input["pages"]; !ok || string(pages) != `""` {
+		return json.RawMessage(raw)
+	}
+
+	delete(input, "pages")
+	sanitized, err := json.Marshal(input)
+	if err != nil {
+		return json.RawMessage(raw)
+	}
+	return sanitized
+}
+
 // ---------------------------------------------------------------------------
 // Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
 // ---------------------------------------------------------------------------

@@ -142,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
 	ContentBlockIndex int
 	ContentBlockOpen  bool
 	CurrentBlockType  string // "text" | "thinking" | "tool_use"
+	CurrentToolName   string
+	CurrentToolArgs   string
 
 	// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
 	OutputIndexToBlockIdx map[int]int

@@ -181,7 +205,7 @@ func ResponsesEventToAnthropicEvents(
 	case "response.function_call_arguments.delta":
 		return resToAnthHandleFuncArgsDelta(evt, state)
 	case "response.function_call_arguments.done":
-		return resToAnthHandleBlockDone(state)
+		return resToAnthHandleFuncArgsDone(evt, state)
 	case "response.output_item.done":
 		return resToAnthHandleOutputItemDone(evt, state)
 	case "response.reasoning_summary_text.delta":

@@ -278,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
 		state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
 		state.ContentBlockOpen = true
 		state.CurrentBlockType = "tool_use"
+		state.CurrentToolName = evt.Item.Name
+		state.CurrentToolArgs = ""
 
 		events = append(events, AnthropicStreamEvent{
 			Type: "content_block_start",

@@ -358,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
 		return nil
 	}
 
+	if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
+		state.CurrentToolArgs += evt.Delta
+		return nil
+	}
+
 	blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
 	if !ok {
 		return nil

@@ -373,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
 	}}
 }
 
+func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
+	if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
+		return resToAnthHandleBlockDone(state)
+	}
+
+	raw := evt.Arguments
+	if raw == "" {
+		raw = state.CurrentToolArgs
+	}
+	sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
+	if len(sanitized) == 0 {
+		return closeCurrentBlock(state)
+	}
+
+	idx := state.ContentBlockIndex
+	events := []AnthropicStreamEvent{{
+		Type:  "content_block_delta",
+		Index: &idx,
+		Delta: &AnthropicDelta{
+			Type:        "input_json_delta",
+			PartialJSON: string(sanitized),
+		},
+	}}
+	events = append(events, closeCurrentBlock(state)...)
+	return events
+}
+
 func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
 	if evt.Delta == "" {
 		return nil

@@ -524,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
 	idx := state.ContentBlockIndex
 	state.ContentBlockOpen = false
 	state.ContentBlockIndex++
+	state.CurrentToolName = ""
+	state.CurrentToolArgs = ""
 	return []AnthropicStreamEvent{{
 		Type:  "content_block_stop",
 		Index: &idx,
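The `sanitizeAnthropicToolUseInput` behavior above can be sketched in isolation: only the `Read` tool is affected, and only when `"pages"` is exactly the empty string does the key get dropped. This standalone version (the function name is illustrative, not the repository's) shows the same filtering:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// dropEmptyPages removes a literal "pages":"" entry from Read-tool arguments;
// any other tool or payload shape is returned untouched.
func dropEmptyPages(name, raw string) string {
	if name != "Read" || raw == "" {
		return raw
	}
	var input map[string]json.RawMessage
	if err := json.Unmarshal([]byte(raw), &input); err != nil {
		return raw // not a JSON object: leave as-is
	}
	if pages, ok := input["pages"]; !ok || string(pages) != `""` {
		return raw // "pages" absent or non-empty: leave as-is
	}
	delete(input, "pages")
	out, err := json.Marshal(input)
	if err != nil {
		return raw
	}
	return string(out)
}

func main() {
	fmt.Println(dropEmptyPages("Read", `{"file_path":"/tmp/a","pages":""}`))   // {"file_path":"/tmp/a"}
	fmt.Println(dropEmptyPages("Search", `{"query":""}`))                      // {"query":""}
}
```

Note the comparison against the raw token `""`: a `"pages"` value of `null`, `0`, or a non-empty string is deliberately preserved.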
@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
 // "auto" → {"type":"auto"}
 // "required" → {"type":"any"}
 // "none" → {"type":"none"}
-// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
+// {"type":"function","name":"X"} → {"type":"tool","name":"X"}
+// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
 func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
 	// Try as string first
 	var s string

@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage
 	// Try as object with type=function
 	var tc struct {
 		Type string `json:"type"`
+		Name string `json:"name"`
 		Function struct {
 			Name string `json:"name"`
 		} `json:"function"`
 	}
-	if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
+	if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" {
+		name := strings.TrimSpace(tc.Name)
+		if name == "" {
+			name = strings.TrimSpace(tc.Function.Name)
+		}
+		if name == "" {
+			return raw, nil
+		}
 		return json.Marshal(map[string]string{
 			"type": "tool",
-			"name": tc.Function.Name,
+			"name": name,
 		})
 	}
 
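The hunks above settle on the flat Responses shape `{"type":"function","name":"X"}` while still accepting the legacy nested `{"type":"function","function":{"name":"X"}}`. The dual parse can be sketched on its own (function name below is illustrative): a single struct with both fields, preferring the flat one.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// toolChoiceName extracts the function name from either the flat shape
// {"type":"function","name":"X"} or the legacy nested shape
// {"type":"function","function":{"name":"X"}}, preferring the flat field.
func toolChoiceName(raw []byte) string {
	var tc struct {
		Type     string `json:"type"`
		Name     string `json:"name"`
		Function struct {
			Name string `json:"name"`
		} `json:"function"`
	}
	if err := json.Unmarshal(raw, &tc); err != nil || tc.Type != "function" {
		return ""
	}
	if name := strings.TrimSpace(tc.Name); name != "" {
		return name
	}
	return strings.TrimSpace(tc.Function.Name)
}

func main() {
	fmt.Println(toolChoiceName([]byte(`{"type":"function","name":"get_weather"}`)))                 // get_weather
	fmt.Println(toolChoiceName([]byte(`{"type":"function","function":{"name":"get_weather"}}`)))    // get_weather
}
```

Because `encoding/json` leaves absent fields zero-valued, one unmarshal handles both shapes; no second decoding pass is needed.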
@@ -2,16 +2,28 @@ package httputil
 
 import (
 	"bytes"
+	"compress/gzip"
+	"compress/zlib"
+	"errors"
+	"fmt"
 	"io"
 	"net/http"
+	"strings"
+
+	"github.com/klauspost/compress/zstd"
 )
 
 const (
 	requestBodyReadInitCap    = 512
 	requestBodyReadMaxInitCap = 1 << 20
+	// maxDecompressedBodySize limits the decompressed request body to 64 MB
+	// to prevent decompression bomb attacks.
+	maxDecompressedBodySize = 64 << 20
 )
 
-// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
+// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
+// on content length, transparently decoding any Content-Encoding the upstream
+// client used to compress the body (zstd, gzip, deflate).
 func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
 	if req == nil || req.Body == nil {
 		return nil, nil

@@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
 	if _, err := io.Copy(buf, req.Body); err != nil {
 		return nil, err
 	}
-	return buf.Bytes(), nil
+	raw := buf.Bytes()
+
+	enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding")))
+	if enc == "" || enc == "identity" {
+		return raw, nil
+	}
+
+	decoded, err := decompressRequestBody(enc, raw)
+	if err != nil {
+		return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err)
+	}
+
+	req.Header.Del("Content-Encoding")
+	req.Header.Del("Content-Length")
+	req.ContentLength = int64(len(decoded))
+
+	return decoded, nil
+}
+
+func decompressRequestBody(encoding string, raw []byte) ([]byte, error) {
+	switch encoding {
+	case "zstd":
+		dec, err := zstd.NewReader(bytes.NewReader(raw))
+		if err != nil {
+			return nil, err
+		}
+		defer dec.Close()
+		return io.ReadAll(io.LimitReader(dec, maxDecompressedBodySize))
+	case "gzip", "x-gzip":
+		gr, err := gzip.NewReader(bytes.NewReader(raw))
+		if err != nil {
+			return nil, err
+		}
+		defer func() { _ = gr.Close() }()
+		return io.ReadAll(io.LimitReader(gr, maxDecompressedBodySize))
+	case "deflate":
+		zr, err := zlib.NewReader(bytes.NewReader(raw))
+		if err != nil {
+			return nil, err
+		}
+		defer func() { _ = zr.Close() }()
+		return io.ReadAll(io.LimitReader(zr, maxDecompressedBodySize))
+	default:
+		return nil, errors.New("unsupported Content-Encoding")
+	}
 }
|||||||
backend/internal/pkg/httputil/body_test.go (new file, 143 lines)
@@ -0,0 +1,143 @@
+package httputil
+
+import (
+	"bytes"
+	"compress/gzip"
+	"compress/zlib"
+	"net/http"
+	"strings"
+	"testing"
+
+	"github.com/klauspost/compress/zstd"
+)
+
+const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}`
+
+func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request {
+	t.Helper()
+	req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
+	if err != nil {
+		t.Fatalf("NewRequest: %v", err)
+	}
+	if encoding != "" {
+		req.Header.Set("Content-Encoding", encoding)
+	}
+	req.ContentLength = int64(len(body))
+	return req
+}
+
+func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) {
+	req := newRequestWithBody(t, []byte(samplePayload), "")
+	got, err := ReadRequestBodyWithPrealloc(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if string(got) != samplePayload {
+		t.Fatalf("body mismatch: got %q", got)
+	}
+}
+
+func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) {
+	enc, _ := zstd.NewWriter(nil)
+	compressed := enc.EncodeAll([]byte(samplePayload), nil)
+	_ = enc.Close()
+
+	req := newRequestWithBody(t, compressed, "zstd")
+	got, err := ReadRequestBodyWithPrealloc(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if string(got) != samplePayload {
+		t.Fatalf("body mismatch: got %q", got)
+	}
+	if req.Header.Get("Content-Encoding") != "" {
+		t.Fatalf("Content-Encoding should be cleared after decoding")
+	}
+	if req.ContentLength != int64(len(samplePayload)) {
+		t.Fatalf("ContentLength not updated: %d", req.ContentLength)
+	}
+}
+
+func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) {
+	var buf bytes.Buffer
+	gw := gzip.NewWriter(&buf)
+	if _, err := gw.Write([]byte(samplePayload)); err != nil {
+		t.Fatalf("gzip write: %v", err)
+	}
+	if err := gw.Close(); err != nil {
+		t.Fatalf("gzip close: %v", err)
+	}
+
+	req := newRequestWithBody(t, buf.Bytes(), "gzip")
+	got, err := ReadRequestBodyWithPrealloc(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if string(got) != samplePayload {
+		t.Fatalf("body mismatch: got %q", got)
+	}
+}
+
+func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) {
+	var buf bytes.Buffer
+	zw := zlib.NewWriter(&buf)
+	if _, err := zw.Write([]byte(samplePayload)); err != nil {
+		t.Fatalf("zlib write: %v", err)
+	}
+	if err := zw.Close(); err != nil {
+		t.Fatalf("zlib close: %v", err)
+	}
+
+	req := newRequestWithBody(t, buf.Bytes(), "deflate")
+	got, err := ReadRequestBodyWithPrealloc(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if string(got) != samplePayload {
+		t.Fatalf("body mismatch: got %q", got)
+	}
+}
+
+func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) {
+	req := newRequestWithBody(t, []byte(samplePayload), "br")
+	_, err := ReadRequestBodyWithPrealloc(req)
+	if err == nil {
+		t.Fatal("expected error for unsupported encoding, got nil")
+	}
+	if !strings.Contains(err.Error(), "br") {
+		t.Fatalf("error should mention encoding, got %v", err)
+	}
+}
+
+func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) {
+	req := newRequestWithBody(t, []byte("not actually zstd"), "zstd")
+	_, err := ReadRequestBodyWithPrealloc(req)
+	if err == nil {
+		t.Fatal("expected error for corrupt zstd body, got nil")
+	}
+}
+
+func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) {
+	req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil)
+	if err != nil {
+		t.Fatalf("NewRequest: %v", err)
+	}
+	got, err := ReadRequestBodyWithPrealloc(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if got != nil {
+		t.Fatalf("expected nil body, got %q", got)
+	}
+}
+
+func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) {
+	req := newRequestWithBody(t, []byte(samplePayload), "identity")
+	got, err := ReadRequestBodyWithPrealloc(req)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if string(got) != samplePayload {
+		t.Fatalf("body mismatch: got %q", got)
+	}
+}
@@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi
 	return true, nil
 }
+
+func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
+	return nil
+}

 func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
 	return nil, nil
 }
@@ -24,6 +24,49 @@ const (
 	defaultSchedulerSnapshotMGetChunkSize  = 128
 	defaultSchedulerSnapshotWriteChunkSize = 256
+
+	// snapshotGraceTTLSeconds is the grace period (in seconds) before an old snapshot expires.
+	// Used instead of an immediate DEL so readers still on the old version have enough time to finish their ZRANGE.
+	snapshotGraceTTLSeconds = 60
+)
+
+var (
+	// activateSnapshotScript atomically CAS-switches the snapshot version.
+	// It switches only when the new version number >= the currently active version, preventing version rollback under concurrent writes.
+	// The old snapshot gets a grace-period EXPIRE instead of an immediate DEL to avoid racing with readers.
+	//
+	// KEYS[1] = activeKey (sched:active:{bucket})
+	// KEYS[2] = readyKey (sched:ready:{bucket})
+	// KEYS[3] = bucketSetKey (sched:buckets)
+	// KEYS[4] = snapshotKey (the newly written snapshot key)
+	// ARGV[1] = new version number, as a string
+	// ARGV[2] = bucket string (for SADD)
+	// ARGV[3] = snapshot key prefix (to build the old snapshot key)
+	// ARGV[4] = grace-period TTL in seconds
+	//
+	// Returns 1 = activated, 0 = version too old, not activated
+	activateSnapshotScript = redis.NewScript(`
+local currentActive = redis.call('GET', KEYS[1])
+local newVersion = tonumber(ARGV[1])
+
+if currentActive ~= false then
+	local curVersion = tonumber(currentActive)
+	if curVersion and newVersion < curVersion then
+		redis.call('DEL', KEYS[4])
+		return 0
+	end
+end
+
+redis.call('SET', KEYS[1], ARGV[1])
+redis.call('SET', KEYS[2], '1')
+redis.call('SADD', KEYS[3], ARGV[2])
+
+if currentActive ~= false and currentActive ~= ARGV[1] then
+	redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
+end
+
+return 1
+`)
 )

 type schedulerCache struct {
@@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
 }

 func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
-	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
-	oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
+	// Phase 1: allocate a new version number and write the snapshot data.
+	// INCR guarantees each caller a unique, monotonically increasing version number.
+	// The snapshotKey written is a new versioned key that readers do not yet know about, so there is no race.
 	versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
 	version, err := c.rdb.Incr(ctx, versionKey).Result()
 	if err != nil {
@@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
 		return err
 	}

-	pipe := c.rdb.Pipeline()
 	if len(accounts) > 0 {
 		// Use the index as the score to preserve the ordering semantics returned by the database.
 		members := make([]redis.Z, 0, len(accounts))
@@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
 				Member: strconv.FormatInt(account.ID, 10),
 			})
 		}
+		pipe := c.rdb.Pipeline()
 		for start := 0; start < len(members); start += c.writeChunkSize {
 			end := start + c.writeChunkSize
 			if end > len(members) {
@@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
 			}
 			pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
 		}
-	} else {
-		pipe.Del(ctx, snapshotKey)
+		if _, err := pipe.Exec(ctx); err != nil {
+			return err
+		}
 	}
-	pipe.Set(ctx, activeKey, versionStr, 0)
-	pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
-	pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
-	if _, err := pipe.Exec(ctx); err != nil {
-		return err
-	}
-
-	if oldActive != "" && oldActive != versionStr {
-		_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
+
+	// Phase 2: atomically CAS-activate the version.
+	// The Lua script guarantees the active pointer only switches when the new version >= the currently active one,
+	// preventing version rollback under concurrent writes.
+	// The old snapshot gets a grace-period EXPIRE instead of an immediate DEL, avoiding a race with readers.
+	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
+	readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
+	snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode)
+
+	keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
+	args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}
+
+	_, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result()
+	if err != nil {
+		return err
 	}

 	return nil
@@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched
 	return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
 }
+
+func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
+	key := schedulerBucketKey(schedulerLockPrefix, bucket)
+	return c.rdb.Del(ctx, key).Err()
+}

 func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
 	raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
 	if err != nil {
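The two-phase publish in the patch (write a versioned snapshot nobody reads yet, then CAS-activate it) can be illustrated without Redis. The sketch below stands in a mutex-guarded map for Redis and mirrors the Lua script's "only switch if the new version is not older than the active one" rule; all names are illustrative:

```go
package main

import (
	"fmt"
	"sync"
)

// snapshotStore sketches the two-phase snapshot publish: phase 1 writes a
// snapshot under a fresh versioned key, phase 2 CAS-activates it. A map with
// a mutex stands in for Redis here.
type snapshotStore struct {
	mu        sync.Mutex
	active    int64 // currently active version (0 = none)
	next      int64 // monotonically increasing version counter (INCR)
	snapshots map[int64][]string
}

func newSnapshotStore() *snapshotStore {
	return &snapshotStore{snapshots: make(map[int64][]string)}
}

// SetSnapshot writes a new versioned snapshot, then tries to activate it.
// It returns the allocated version and whether activation happened.
func (s *snapshotStore) SetSnapshot(accounts []string) (int64, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.next++ // like INCR: each caller gets a unique, increasing version
	v := s.next
	s.snapshots[v] = accounts // phase 1: write under a fresh key

	// phase 2: CAS activation — never move the active pointer backwards,
	// matching the Lua script's "newVersion < curVersion => discard" branch.
	if v < s.active {
		delete(s.snapshots, v) // stale write, discard it
		return v, false
	}
	s.active = v
	return v, true
}

// GetSnapshot reads through the active pointer, as readers do via the
// sched:active:{bucket} key in the patch.
func (s *snapshotStore) GetSnapshot() []string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.snapshots[s.active]
}

func main() {
	st := newSnapshotStore()
	st.SetSnapshot([]string{"acct-1"})
	st.SetSnapshot([]string{"acct-1", "acct-2"})
	fmt.Println(st.GetSnapshot())
}
```

In the real patch the two phases are separate round trips, so the CAS must live in a Lua script to be atomic; the grace-period EXPIRE on the old snapshot (instead of DEL) is what keeps in-flight ZRANGE readers safe.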
@@ -748,6 +748,16 @@ func TestAPIContracts(t *testing.T) {
 				"payment_visible_method_alipay_enabled": true,
 				"payment_visible_method_wxpay_enabled": false,
 				"openai_advanced_scheduler_enabled": true,
+				"openai_fast_policy_settings": {
+					"rules": [
+						{
+							"service_tier": "priority",
+							"action": "filter",
+							"scope": "all",
+							"fallback_action": "pass"
+						}
+					]
+				},
 				"custom_menu_items": [],
 				"custom_endpoints": [],
 				"payment_enabled": false,
@@ -930,6 +940,16 @@ func TestAPIContracts(t *testing.T) {
 				"payment_visible_method_alipay_enabled": false,
 				"payment_visible_method_wxpay_enabled": false,
 				"openai_advanced_scheduler_enabled": false,
+				"openai_fast_policy_settings": {
+					"rules": [
+						{
+							"service_tier": "priority",
+							"action": "filter",
+							"scope": "all",
+							"fallback_action": "pass"
+						}
+					]
+				},
 				"payment_enabled": false,
 				"payment_min_amount": 0,
 				"payment_max_amount": 0,
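The contract above only pins the JSON shape of `openai_fast_policy_settings`; the scheduler's actual evaluation logic is not part of this diff. As a heavily hedged illustration, a rule of the shape shown (match a service tier, apply an action, fall back when filtering leaves no candidates) could be evaluated like this — the semantics here are an assumption, not the repo's implementation:

```go
package main

import "fmt"

// FastPolicyRule mirrors the JSON shape in the contract test above.
// The evaluation semantics below are illustrative assumptions.
type FastPolicyRule struct {
	ServiceTier    string `json:"service_tier"`
	Action         string `json:"action"` // e.g. "filter"
	Scope          string `json:"scope"`  // e.g. "all"
	FallbackAction string `json:"fallback_action"`
}

// resolveAction picks the action for a request's service tier, falling back
// to the rule's fallback_action when filtering would leave zero candidates.
func resolveAction(rules []FastPolicyRule, tier string, candidates int) string {
	for _, r := range rules {
		if r.ServiceTier != tier {
			continue
		}
		if r.Action == "filter" && candidates == 0 {
			return r.FallbackAction // nothing left to filter to: degrade gracefully
		}
		return r.Action
	}
	return "pass" // no rule matched: default to passing through
}

func main() {
	rules := []FastPolicyRule{{
		ServiceTier: "priority", Action: "filter", Scope: "all", FallbackAction: "pass",
	}}
	fmt.Println(resolveAction(rules, "priority", 3)) // rule applies
	fmt.Println(resolveAction(rules, "priority", 0)) // fallback path
}
```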
@@ -64,6 +64,7 @@ func isOpenAIImageModel(model string) bool {
 type AccountTestService struct {
 	accountRepo               AccountRepository
 	geminiTokenProvider       *GeminiTokenProvider
+	claudeTokenProvider       *ClaudeTokenProvider
 	antigravityGatewayService *AntigravityGatewayService
 	httpUpstream              HTTPUpstream
 	cfg                       *config.Config
@@ -74,6 +75,7 @@ type AccountTestService struct {
 func NewAccountTestService(
 	accountRepo AccountRepository,
 	geminiTokenProvider *GeminiTokenProvider,
+	claudeTokenProvider *ClaudeTokenProvider,
 	antigravityGatewayService *AntigravityGatewayService,
 	httpUpstream HTTPUpstream,
 	cfg *config.Config,
@@ -82,6 +84,7 @@ func NewAccountTestService(
 	return &AccountTestService{
 		accountRepo:               accountRepo,
 		geminiTokenProvider:       geminiTokenProvider,
+		claudeTokenProvider:       claudeTokenProvider,
 		antigravityGatewayService: antigravityGatewayService,
 		httpUpstream:              httpUpstream,
 		cfg:                       cfg,
@@ -210,6 +213,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
 	if account.IsBedrock() {
 		return s.testBedrockAccountConnection(c, ctx, account, testModelID)
 	}
+	if account.Type == AccountTypeServiceAccount {
+		return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID)
+	}

 	// Determine authentication method and API URL
 	var authToken string
@@ -313,6 +319,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
 	return s.processClaudeStream(c, resp.Body)
 }
+
+func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
+	if mappedModel, matched := account.ResolveMappedModel(testModelID); matched {
+		testModelID = mappedModel
+	} else {
+		testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID))
+	}
+
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Writer.Flush()
+
+	payload, err := createTestPayload(testModelID)
+	if err != nil {
+		return s.sendErrorAndEnd(c, "Failed to create test payload")
+	}
+	payloadBytes, _ := json.Marshal(payload)
+	vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error()))
+	}
+
+	if s.claudeTokenProvider == nil {
+		return s.sendErrorAndEnd(c, "Claude token provider not configured")
+	}
+	accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error()))
+	}
+
+	fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error()))
+	}
+
+	s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
+	if err != nil {
+		return s.sendErrorAndEnd(c, "Failed to create request")
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+accessToken)
+
+	proxyURL := ""
+	if account.ProxyID != nil && account.Proxy != nil {
+		proxyURL = account.Proxy.URL()
+	}
+
+	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
+		if resp.StatusCode == http.StatusForbidden {
+			_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
+		}
+		return s.sendErrorAndEnd(c, errMsg)
+	}
+
+	return s.processClaudeStream(c, resp.Body)
+}
+
 // testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
 func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
 	region := bedrockRuntimeRegion(account)
@@ -711,8 +785,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
 		testModelID = geminicli.DefaultTestModel
 	}

-	// For API Key accounts with model mapping, map the model
-	if account.Type == AccountTypeAPIKey {
+	// For static upstream credentials with model mapping, map the model
+	if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
 		mapping := account.GetModelMapping()
 		if len(mapping) > 0 {
 			if mappedModel, exists := mapping[testModelID]; exists {
@@ -740,6 +814,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
 		req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
 	case AccountTypeOAuth:
 		req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
+	case AccountTypeServiceAccount:
+		req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload)
 	default:
 		return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
 	}
@@ -893,6 +969,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
 	return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
 }
+
+func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
+	if s.geminiTokenProvider == nil {
+		return nil, fmt.Errorf("gemini token provider not configured")
+	}
+	accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get service account access token: %w", err)
+	}
+	fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true)
+	if err != nil {
+		return nil, err
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+accessToken)
+	return req, nil
+}
+
 // buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
 func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
 	var inner map[string]any
@@ -1227,7 +1324,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
 	if err != nil {
 		return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
 	}
-	apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
+	apiURL := buildOpenAIImagesURL(normalizedBaseURL, openAIImagesGenerationsEndpoint)

 	// Set SSE headers
 	c.Writer.Header().Set("Content-Type", "text/event-stream")
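The service-account paths above delegate endpoint construction to the repo's `buildVertexAnthropicURL` / `buildVertexGeminiURL` helpers, which this diff does not show. As a sketch of what such a builder typically looks like — assuming Google's documented publisher-model URL layout for Vertex AI, with illustrative names that may differ from the repo's actual helpers:

```go
package main

import "fmt"

// buildVertexURL sketches a Vertex AI publisher-model endpoint builder for
// Anthropic models. Streaming requests use :streamRawPredict, non-streaming
// use :rawPredict; the repo's real helper may differ in details.
func buildVertexURL(projectID, location, model string, stream bool) (string, error) {
	if projectID == "" || location == "" || model == "" {
		return "", fmt.Errorf("projectID, location, and model are all required")
	}
	action := "rawPredict"
	if stream {
		action = "streamRawPredict"
	}
	return fmt.Sprintf(
		"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
		location, projectID, location, model, action,
	), nil
}

func main() {
	// Example values only; any project/location/model IDs are hypothetical.
	u, _ := buildVertexURL("my-project", "us-east5", "claude-sonnet-4", true)
	fmt.Println(u)
}
```

Note how the region appears twice in the URL (subdomain and path), which is why the patch resolves `account.VertexLocation(modelID)` once and passes it into the builder.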
@@ -8,6 +8,7 @@ import (
 	"strings"
 	"testing"

+	"github.com/Wei-Shaw/sub2api/internal/config"
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/require"
 )
@@ -48,3 +49,42 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes
 	require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
 	require.Contains(t, rec.Body.String(), "\"success\":true")
 }
+
+func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+
+	upstream := &httpUpstreamRecorder{
+		resp: &http.Response{
+			StatusCode: http.StatusOK,
+			Header: http.Header{
+				"Content-Type": []string{"application/json"},
+			},
+			Body: io.NopCloser(strings.NewReader(`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
+		},
+	}
+	svc := &AccountTestService{
+		httpUpstream: upstream,
+		cfg:          &config.Config{},
+	}
+	account := &Account{
+		ID:       54,
+		Name:     "openai-apikey",
+		Platform: PlatformOpenAI,
+		Type:     AccountTypeAPIKey,
+		Credentials: map[string]any{
+			"api_key":  "test-api-key",
+			"base_url": "https://image-upstream.example/v1",
+		},
+	}
+
+	err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat")
+	require.NoError(t, err)
+	require.NotNil(t, upstream.lastReq)
+	require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
+	require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
+	require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
+	require.Contains(t, rec.Body.String(), "\"success\":true")
+}
@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"net/http"
 	"sort"
+	"strconv"
 	"strings"
 	"time"

@@ -58,6 +59,7 @@ type AdminService interface {

 	// API Key management (admin)
 	AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
+	AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error)

 	// ReplaceUserGroup replaces a user's dedicated group: grants the new group's permission, migrates keys, and revokes the old group's permission
 	ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
@@ -291,6 +293,7 @@ type UpdateAccountInput struct {
 // BulkUpdateAccountsInput describes the payload for bulk updating accounts.
 type BulkUpdateAccountsInput struct {
 	AccountIDs  []int64
+	Filters     *BulkUpdateAccountFilters
 	Name        string
 	ProxyID     *int64
 	Concurrency *int
@@ -307,6 +310,15 @@ type BulkUpdateAccountsInput struct {
 	SkipMixedChannelCheck bool
 }
+
+type BulkUpdateAccountFilters struct {
+	Platform    string
+	Type        string
+	Status      string
+	Group       string
+	Search      string
+	PrivacyMode string
+}

 // BulkUpdateAccountResult captures the result for a single account update.
 type BulkUpdateAccountResult struct {
 	AccountID int64 `json:"account_id"`
@@ -1961,6 +1973,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
 	return result, nil
 }
+
+// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows.
+func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) {
+	apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
+	if err != nil {
+		return nil, err
+	}
+	apiKey.Usage5h = 0
+	apiKey.Usage1d = 0
+	apiKey.Usage7d = 0
+	apiKey.Window5hStart = nil
+	apiKey.Window1dStart = nil
+	apiKey.Window7dStart = nil
+	if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
+		return nil, fmt.Errorf("reset api key rate limit usage: %w", err)
+	}
+	if s.authCacheInvalidator != nil {
+		s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+	}
+	if s.billingCacheService != nil {
+		_ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
+	}
+	return apiKey, nil
+}

 // ReplaceUserGroup replaces a user's dedicated group
 func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) {
 	if oldGroupID == newGroupID {
@@ -2286,6 +2322,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
 // BulkUpdateAccounts updates multiple accounts in one request.
 // It merges credentials/extra keys instead of overwriting the whole object.
 func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
+	if len(input.AccountIDs) == 0 && input.Filters != nil {
+		accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters)
+		if err != nil {
+			return nil, err
+		}
+		input.AccountIDs = accountIDs
+	}
+
 	result := &BulkUpdateAccountsResult{
 		SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
 		FailedIDs:  make([]int64, 0, len(input.AccountIDs)),
@@ -2401,6 +2445,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
 	return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) {
|
||||||
|
if filters == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
groupID := int64(0)
|
||||||
|
switch strings.TrimSpace(filters.Group) {
|
||||||
|
case "":
|
||||||
|
case "ungrouped":
|
||||||
|
groupID = AccountListGroupUngrouped
|
||||||
|
default:
|
||||||
|
parsedGroupID, err := strconv.ParseInt(strings.TrimSpace(filters.Group), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid group filter: %w", err)
|
||||||
|
}
|
||||||
|
groupID = parsedGroupID
|
||||||
|
}
|
||||||
|
|
||||||
|
const pageSize = 500
|
||||||
|
page := 1
|
||||||
|
accountIDs := make([]int64, 0, pageSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
accounts, total, err := s.ListAccounts(
|
||||||
|
ctx,
|
||||||
|
page,
|
||||||
|
pageSize,
|
||||||
|
filters.Platform,
|
||||||
|
filters.Type,
|
||||||
|
filters.Status,
|
||||||
|
filters.Search,
|
||||||
|
groupID,
|
||||||
|
filters.PrivacyMode,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, account := range accounts {
|
||||||
|
accountIDs = append(accountIDs, account.ID)
|
||||||
|
}
|
||||||
|
if int64(len(accountIDs)) >= total || len(accounts) == 0 {
|
||||||
|
return accountIDs, nil
|
||||||
|
}
|
||||||
|
page++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ package service
 import (
 	"context"
 	"errors"
+	"reflect"
 	"testing"
 
+	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
 	"github.com/stretchr/testify/require"
 )
 
@@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
 	getByIDCalled   []int64
 	listByGroupData map[int64][]Account
 	listByGroupErr  map[int64]error
+	listData        []Account
+	listResult      *pagination.PaginationResult
+	listErr         error
+	listCalled      bool
+	lastListParams  pagination.PaginationParams
+	lastListFilters struct {
+		platform    string
+		accountType string
+		status      string
+		search      string
+		groupID     int64
+		privacyMode string
+	}
 }
 
 func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
@@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
 	return nil, nil
 }
+
+func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
+	s.listCalled = true
+	s.lastListParams = params
+	s.lastListFilters.platform = platform
+	s.lastListFilters.accountType = accountType
+	s.lastListFilters.status = status
+	s.lastListFilters.search = search
+	s.lastListFilters.groupID = groupID
+	s.lastListFilters.privacyMode = privacyMode
+	if s.listErr != nil {
+		return nil, nil, s.listErr
+	}
+	if s.listResult != nil {
+		return s.listData, s.listResult, nil
+	}
+	return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil
+}
 
 // TestAdminService_BulkUpdateAccounts_AllSuccessIDs verifies that success_ids/failed_ids are returned when the bulk update succeeds.
 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
 	repo := &accountRepoStubForBulkUpdate{}
@@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
 	// No BindGroups should have been called since the check runs before any write.
 	require.Empty(t, repo.bindGroupsCalls)
 }
+
+func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) {
+	repo := &accountRepoStubForBulkUpdate{
+		listData: []Account{
+			{ID: 7},
+			{ID: 11},
+		},
+		listResult: &pagination.PaginationResult{Total: 2},
+	}
+	svc := &adminServiceImpl{accountRepo: repo}
+
+	schedulable := true
+	input := &BulkUpdateAccountsInput{
+		Schedulable: &schedulable,
+	}
+
+	filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters")
+	require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update")
+	require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field")
+
+	filtersValue := reflect.New(filtersField.Type().Elem())
+	filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI)
+	filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth)
+	filtersValue.Elem().FieldByName("Status").SetString(StatusActive)
+	filtersValue.Elem().FieldByName("Group").SetString("12")
+	filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked)
+	filtersValue.Elem().FieldByName("Search").SetString("bulk-target")
+	filtersField.Set(filtersValue)
+
+	result, err := svc.BulkUpdateAccounts(context.Background(), input)
+	require.NoError(t, err)
+	require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters")
+	require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform)
+	require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType)
+	require.Equal(t, StatusActive, repo.lastListFilters.status)
+	require.Equal(t, "bulk-target", repo.lastListFilters.search)
+	require.Equal(t, int64(12), repo.lastListFilters.groupID)
+	require.Equal(t, PrivacyModeCFBlocked, repo.lastListFilters.privacyMode)
+	require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs)
+	require.Equal(t, 2, result.Success)
+	require.Equal(t, 0, result.Failed)
+	require.Equal(t, []int64{7, 11}, result.SuccessIDs)
+}
@@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
 	return nil
 }
+
+// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key.
+func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
+	if s.cache == nil {
+		return nil
+	}
+	if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil {
+		logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err)
+		return err
+	}
+	return nil
+}
 
 // ============================================
 // API Key rate-limit cache methods
 // ============================================
@@ -17,7 +17,7 @@ const (
 // ClaudeTokenCache token cache interface.
 type ClaudeTokenCache = GeminiTokenCache
 
-// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
+// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts.
 type ClaudeTokenProvider struct {
 	accountRepo AccountRepository
 	tokenCache  ClaudeTokenCache
@@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
 	if account == nil {
 		return "", errors.New("account is nil")
 	}
-	if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
-		return "", errors.New("not an anthropic oauth account")
+	if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
+		return "", errors.New("not an anthropic oauth or service account")
+	}
+	if account.Type == AccountTypeServiceAccount {
+		return p.getServiceAccountAccessToken(ctx, account)
 	}
 
 	cacheKey := ClaudeTokenCacheKey(account)
@@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
 
 	return accessToken, nil
 }
+
+func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
+	return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
+}
@@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
 		return "", errors.New("account is nil")
 	}
 	if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
-		return "", errors.New("not an anthropic oauth account")
+		return "", errors.New("not an anthropic oauth or service account")
 	}
 
 	cacheKey := ClaudeTokenCacheKey(account)
@@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
 
 	token, err := provider.GetAccessToken(context.Background(), account)
 	require.Error(t, err)
-	require.Contains(t, err.Error(), "not an anthropic oauth account")
+	require.Contains(t, err.Error(), "not an anthropic oauth or service account")
 	require.Empty(t, token)
 }
 
@@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
 
 	token, err := provider.GetAccessToken(context.Background(), account)
 	require.Error(t, err)
-	require.Contains(t, err.Error(), "not an anthropic oauth account")
+	require.Contains(t, err.Error(), "not an anthropic oauth or service account")
 	require.Empty(t, token)
 }
 
@@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
 
 	token, err := provider.GetAccessToken(context.Background(), account)
 	require.Error(t, err)
-	require.Contains(t, err.Error(), "not an anthropic oauth account")
+	require.Contains(t, err.Error(), "not an anthropic oauth or service account")
 	require.Empty(t, token)
 }
 
@@ -41,11 +41,12 @@ const (
 
 // Account type constants
 const (
 	AccountTypeOAuth      = domain.AccountTypeOAuth      // OAuth account (full scope: profile + inference)
 	AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token account (inference-only scope)
 	AccountTypeAPIKey     = domain.AccountTypeAPIKey     // API Key account
 	AccountTypeUpstream   = domain.AccountTypeUpstream   // Upstream passthrough account (connects upstream via Base URL + API Key)
 	AccountTypeBedrock    = domain.AccountTypeBedrock    // AWS Bedrock account (connects via SigV4 signing or API Key, distinguished by credentials.auth_mode)
+	AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account (used for Vertex AI)
 )
 
 // Redeem type constants
@@ -306,6 +307,12 @@ const (
 	// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
 	SettingKeyBetaPolicySettings = "beta_policy_settings"
+
+	// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
+	// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
+	// targets OpenAI's body-level service_tier field instead of Claude's
+	// anthropic-beta header.
+	SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
 
 	// =========================
 	// Claude Code Version Check
 	// =========================
@@ -0,0 +1,68 @@
+package service
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
+)
+
+func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+	c.Request.Header.Set("Authorization", "Bearer inbound-token")
+	c.Request.Header.Set("X-Api-Key", "inbound-api-key")
+	c.Request.Header.Set("Anthropic-Version", "2023-06-01")
+	c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
+
+	account := &Account{
+		ID:       301,
+		Platform: PlatformAnthropic,
+		Type:     AccountTypeServiceAccount,
+		Credentials: map[string]any{
+			"project_id": "vertex-proj",
+			"location":   "us-east5",
+		},
+	}
+	body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`)
+
+	svc := &GatewayService{}
+	req, err := svc.buildUpstreamRequest(
+		context.Background(),
+		c,
+		account,
+		body,
+		"vertex-token",
+		"service_account",
+		"claude-sonnet-4-5@20250929",
+		false,
+		false,
+	)
+	require.NoError(t, err)
+	require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String())
+	require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization"))
+	require.Empty(t, getHeaderRaw(req.Header, "x-api-key"))
+	require.Empty(t, getHeaderRaw(req.Header, "anthropic-version"))
+	require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, "anthropic-beta"))
+
+	got := readRequestBodyForTest(t, req)
+	require.Equal(t, "", gjson.GetBytes(got, "model").String())
+	require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
+	require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String())
+}
+
+func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
+	t.Helper()
+	require.NotNil(t, req.Body)
+	body, err := io.ReadAll(req.Body)
+	require.NoError(t, err)
+	return body
+}
@@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
 
 	// 4. Model mapping
 	mappedModel := originalModel
-	if account.Type == AccountTypeAPIKey {
+	if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
 		mappedModel = account.GetMappedModel(originalModel)
 	}
-	if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
+	if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+		normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
+		if normalized != originalModel {
+			mappedModel = normalized
+		}
+	} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
 		normalized := claude.NormalizeModelID(originalModel)
 		if normalized != originalModel {
 			mappedModel = normalized
@@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
 	// 4. Model mapping
 	mappedModel := originalModel
 	reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
-	if account.Type == AccountTypeAPIKey {
+	if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
 		mappedModel = account.GetMappedModel(originalModel)
 	}
-	if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
+	if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+		normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
+		if normalized != originalModel {
+			mappedModel = normalized
+		}
+	} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
 		normalized := claude.NormalizeModelID(originalModel)
 		if normalized != originalModel {
 			mappedModel = normalized
@@ -11,6 +11,7 @@ import (
 	"io"
 	"log/slog"
 	mathrand "math/rand"
+	"net"
 	"net/http"
 	"net/url"
 	"os"
@@ -20,6 +21,7 @@ import (
 	"strconv"
 	"strings"
 	"sync/atomic"
+	"syscall"
 	"time"
 
 	"github.com/Wei-Shaw/sub2api/internal/config"
@@ -3597,7 +3599,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
 	}
 	// OAuth/SetupToken accounts use the standard Anthropic mapping (short ID → long ID)
 	if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
-		requestedModel = claude.NormalizeModelID(requestedModel)
+		if account.Type == AccountTypeServiceAccount {
+			requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel))
+		} else {
+			requestedModel = claude.NormalizeModelID(requestedModel)
+		}
 	}
 	// Other platforms use the account's model-support check
 	return account.IsModelSupported(requestedModel)
@@ -3617,6 +3623,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
 		return apiKey, "apikey", nil
 	case AccountTypeBedrock:
 		return "", "bedrock", nil // Bedrock uses SigV4 signing or an API Key; handled by forwardBedrock
+	case AccountTypeServiceAccount:
+		if account.Platform != PlatformAnthropic {
+			return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform)
+		}
+		if s.claudeTokenProvider == nil {
+			return "", "", errors.New("claude token provider not configured")
+		}
+		accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
+		if err != nil {
+			return "", "", err
+		}
+		return accessToken, "service_account", nil
 	default:
 		return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
 	}
@@ -4219,6 +4237,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
 			mappingSource = "account"
 		}
 	}
+	if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+		if candidate, matched := account.ResolveMappedModel(reqModel); matched {
+			mappedModel = candidate
+			mappingSource = "account"
+		} else {
+			normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel))
+			if normalized != reqModel {
+				mappedModel = normalized
+				mappingSource = "vertex"
+			}
+		}
+	}
 	if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
 		normalized := claude.NormalizeModelID(reqModel)
 		if normalized != reqModel {
@@ -5688,6 +5718,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
 }
 
 func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
+	if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+		return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
+	}
+
 	// Determine the target URL
 	targetURL := claudeAPIURL
 	if account.Type == AccountTypeAPIKey {
@@ -5874,6 +5908,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
 	return req, nil
 }
+
+func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
+	ctx context.Context,
+	c *gin.Context,
+	account *Account,
+	body []byte,
+	token string,
+	modelID string,
+	reqStream bool,
+) (*http.Request, error) {
+	vertexBody, err := buildVertexAnthropicRequestBody(body)
+	if err != nil {
+		return nil, err
+	}
+	setOpsUpstreamRequestBody(c, vertexBody)
+	fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
+	if err != nil {
+		return nil, err
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
+	if err != nil {
+		return nil, err
+	}
+
+	if c != nil && c.Request != nil {
+		for key, values := range c.Request.Header {
+			lowerKey := strings.ToLower(strings.TrimSpace(key))
+			if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" {
+				continue
+			}
+			wireKey := resolveWireCasing(key)
+			for _, v := range values {
+				addHeaderRaw(req.Header, wireKey, v)
+			}
+		}
+	}
+
+	req.Header.Del("authorization")
+	req.Header.Del("x-api-key")
+	req.Header.Del("x-goog-api-key")
+	req.Header.Del("cookie")
+	req.Header.Del("anthropic-version")
+	setHeaderRaw(req.Header, "authorization", "Bearer "+token)
+	setHeaderRaw(req.Header, "content-type", "application/json")
+
+	s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
+		"url":        req.URL.String(),
+		"token_type": "service_account",
+		"model":      modelID,
+		"stream":     strconv.FormatBool(reqStream),
+	})
+
+	return req, nil
+}
 
 // getBetaHeader handles the anthropic-beta header
 // For OAuth accounts, it must ensure oauth-2025-04-20 is included
 func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
@@ -6434,6 +6522,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
 	return false
 }
+
+// sanitizeStreamError returns a client-visible error description containing no network addresses.
+// By default (*net.OpError).Error() concatenates the Source/Addr fields, leaking internal
+// IPs/ports and the upstream server address (e.g. "read tcp 10.0.0.1:54321->52.1.2.3:443:
+// read: connection reset by peer"). This function keeps only a recognizable error category;
+// the original err is still logged at the call site.
+func sanitizeStreamError(err error) string {
+	if err == nil {
+		return ""
+	}
+	switch {
+	case errors.Is(err, io.ErrUnexpectedEOF):
+		return "unexpected EOF"
+	case errors.Is(err, io.EOF):
+		return "EOF"
+	case errors.Is(err, context.Canceled):
+		return "canceled"
+	case errors.Is(err, context.DeadlineExceeded):
+		return "deadline exceeded"
+	case errors.Is(err, syscall.ECONNRESET):
+		return "connection reset by peer"
+	case errors.Is(err, syscall.ECONNABORTED):
+		return "connection aborted"
+	case errors.Is(err, syscall.ETIMEDOUT):
+		return "connection timed out"
+	case errors.Is(err, syscall.EPIPE):
+		return "broken pipe"
+	case errors.Is(err, syscall.ECONNREFUSED):
+		return "connection refused"
+	}
+	var netErr *net.OpError
+	if errors.As(err, &netErr) {
+		if netErr.Timeout() {
+			if netErr.Op != "" {
+				return netErr.Op + " timeout"
+			}
+			return "i/o timeout"
+		}
+		if netErr.Op != "" {
+			return netErr.Op + " network error"
+		}
+	}
+	return "upstream connection error"
+}
 
 // ExtractUpstreamErrorMessage extracts the error message from an upstream response body
 // Supports the Claude-style error format: {"type":"error","error":{"type":"...","message":"..."}}
 func ExtractUpstreamErrorMessage(body []byte) string {
@@ -6871,14 +7002,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
 	}
 	lastDataAt := time.Now()
 
-	// Send the error event at most once, to avoid corrupting the protocol with multiple writes (best-effort client notification on write failure)
+	// Send the error event at most once, to avoid corrupting the protocol with multiple writes (best-effort client notification on write failure).
+	// The event format follows the Anthropic SSE standard: {"type":"error","error":{"type":<reason>,"message":<message>}}
+	// so clients such as the Anthropic SDK / Claude Code can parse it as a standard error type, UIs can display the concrete message,
+	// and the server-side ExtractUpstreamErrorMessage can extract the message from the passed-through body.
 	errorEventSent := false
-	sendErrorEvent := func(reason string) {
+	sendErrorEvent := func(reason, message string) {
 		if errorEventSent {
 			return
 		}
 		errorEventSent = true
-		_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
+		if message == "" {
+			message = reason
+		}
+		body, err := json.Marshal(map[string]any{
+			"type": "error",
+			"error": map[string]string{
+				"type":    reason,
+				"message": message,
+			},
+		})
+		if err != nil {
+			// json.Marshal cannot fail on known string-only input; conservative fallback.
+			body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message))
+		}
+		_, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body)
 		flusher.Flush()
 	}
 
@@ -7038,10 +7186,32 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
 			// Client has not disconnected; normal error handling.
 			if errors.Is(ev.err, bufio.ErrTooLong) {
 				logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
-				sendErrorEvent("response_too_large")
+				sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize))
 				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
 			}
-			sendErrorEvent("stream_read_error")
+			// Mid-stream upstream read error (unexpected EOF / connection reset etc., commonly from HTTP/2 GOAWAY):
+			// if no bytes have been written to the client yet, wrap it as UpstreamFailoverError so the handler layer can fail over / retry.
+			// Once the stream has started, SSE has no resume; the only option is to pass the error event through to the client.
+			// Note: the client-facing disconnectMsg must strip addresses via sanitizeStreamError, since
+			// the default *net.OpError Error() leaks internal IPs/ports and the upstream address. The full ev.err
+			// is kept only in the internal LegacyPrintf log below for ops diagnostics.
+			disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err)
+			if !c.Writer.Written() {
+				logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err)
+				body, _ := json.Marshal(map[string]any{
+					"type": "error",
+					"error": map[string]string{
+						"type":    "upstream_disconnected",
+						"message": disconnectMsg,
+					},
+				})
+				return nil, &UpstreamFailoverError{
+					StatusCode:             http.StatusBadGateway,
+					ResponseBody:           body,
+					RetryableOnSameAccount: true,
+				}
+			}
+			sendErrorEvent("stream_read_error", disconnectMsg)
 			return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
 			}
 			line := ev.line
@@ -7100,7 +7270,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
 			if s.rateLimitService != nil {
 				s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
 			}
-			sendErrorEvent("stream_timeout")
+			sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval))
 			return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
 
 		case <-keepaliveCh:
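The hunk above boils down to one routing decision. A minimal sketch of it (the function name and string labels are illustrative, not the project's API):

```go
package main

import "fmt"

// decide captures the branch the gateway takes when an upstream read error
// arrives mid-stream, keyed on whether any bytes already reached the client.
func decide(bytesWrittenToClient bool) string {
	if !bytesWrittenToClient {
		// Nothing reached the client yet: safe to fail over to another account
		// and retry the whole request transparently.
		return "failover"
	}
	// SSE has no resume semantics, so a partially delivered stream can only
	// be terminated with a final error event.
	return "passthrough-error-event"
}

func main() {
	fmt.Println(decide(false)) // failover
	fmt.Println(decide(true))  // passthrough-error-event
}
```

In the real code the check is `c.Writer.Written()`: gin's writer tracks whether the status line or any body bytes have been flushed.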
@@ -4,9 +4,12 @@ package service
 
 import (
 	"context"
+	"errors"
 	"io"
+	"net"
 	"net/http"
 	"net/http/httptest"
+	"syscall"
 	"testing"
 	"time"
 
@@ -218,3 +221,175 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
 	body := rec.Body.String()
 	require.Contains(t, body, "content_block_delta", "response should contain the forwarded SSE event")
 }
+
+// A mid-stream upstream read error (e.g. an unexpected EOF triggered by HTTP/2 GOAWAY) before any bytes are written to the client:
+// the gateway should return *UpstreamFailoverError to trigger account failover/retry, not send the error event straight to the client.
+func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	svc := newMinimalGatewayService()
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+		Body:       &streamReadCloser{err: io.ErrUnexpectedEOF},
+	}
+
+	result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+
+	require.Error(t, err)
+	require.Nil(t, result, "no streamingResult should be returned in the failover case")
+
+	var failoverErr *UpstreamFailoverError
+	require.True(t, errors.As(err, &failoverErr), "a stream read error before any output must be wrapped as UpstreamFailoverError, got: %v", err)
+	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+	require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY-style errors should allow retry on the same account")
+
+	// ResponseBody must be Anthropic's standard error format:
+	// 1) ExtractUpstreamErrorMessage can correctly extract the message from error.message (relied on by handleFailoverExhausted / ops logs)
+	// 2) error.type is marked upstream_disconnected
+	extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody)
+	require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage must extract a non-empty message from ResponseBody, otherwise ops logs lose diagnostics")
+	require.Contains(t, extractedMsg, "upstream stream disconnected")
+	require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`)
+	require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`)
+
+	// The client should not receive any stream_read_error event; the handler layer decides based on the failover outcome.
+	require.NotContains(t, rec.Body.String(), "stream_read_error")
+}
+
+// A read error after the upstream has already sent events (c.Writer has written bytes):
+// SSE has no resume, so the gateway can only pass a stream_read_error event through to the client; it must not fail over.
+func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	svc := newMinimalGatewayService()
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+	// The first Read returns a complete SSE event so the gateway writes bytes to the client; the second Read returns EOF.
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+		Body: &streamReadCloser{
+			payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
+			err:     io.ErrUnexpectedEOF,
+		},
+	}
+
+	result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "stream read error", "after the stream has started, a plain stream read error should pass through")
+	require.NotNil(t, result, "the pass-through case should return the collected streamingResult")
+
+	// Must not be wrongly wrapped as a failover error.
+	var failoverErr *UpstreamFailoverError
+	require.False(t, errors.As(err, &failoverErr), "cannot fail over once bytes have been written to the client")
+
+	// The client must receive an Anthropic-standard SSE error event with error.type=stream_read_error and
+	// error.message carrying the concrete root cause (so SDKs can parse it and UIs can display a specific error).
+	body := rec.Body.String()
+	require.Contains(t, body, "event: error\n", "the error event frame must follow the Anthropic SSE standard")
+	require.Contains(t, body, `"type":"error"`, "data must carry the top-level type:error field (Anthropic standard)")
+	require.Contains(t, body, `"stream_read_error"`, "error.type must be stream_read_error")
+	require.Contains(t, body, "upstream stream disconnected", "error.message must include the concrete root cause so clients like Claude Code can display a useful error")
+}
+
+// By default (*net.OpError).Error() concatenates the Source/Addr fields, leaking internal IPs/ports and the
+// upstream server address. sanitizeStreamError must strip this information so infrastructure topology cannot
+// reach clients via the failover ResponseBody or SSE error frames.
+func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) {
+	src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
+	require.NoError(t, err)
+	dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
+	require.NoError(t, err)
+
+	raw := &net.OpError{
+		Op:     "read",
+		Net:    "tcp",
+		Source: src,
+		Addr:   dst,
+		Err:    syscall.ECONNRESET,
+	}
+
+	// Precondition: the raw Error() really does contain the leaking fields (so the test does not silently pass if Go's behavior changes).
+	require.Contains(t, raw.Error(), "10.0.0.1")
+	require.Contains(t, raw.Error(), "52.1.2.3")
+
+	got := sanitizeStreamError(raw)
+	require.NotContains(t, got, "10.0.0.1", "must not leak the internal source IP")
+	require.NotContains(t, got, "54321", "must not leak the source port")
+	require.NotContains(t, got, "52.1.2.3", "must not leak the upstream destination IP")
+	require.NotContains(t, got, "443", "must not leak the upstream port")
+	require.Equal(t, "connection reset by peer", got)
+}
+
+func TestSanitizeStreamError_KnownErrors(t *testing.T) {
+	cases := []struct {
+		name string
+		err  error
+		want string
+	}{
+		{"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"},
+		{"EOF", io.EOF, "EOF"},
+		{"context canceled", context.Canceled, "canceled"},
+		{"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"},
+		{"ECONNRESET direct", syscall.ECONNRESET, "connection reset by peer"},
+		{"EPIPE", syscall.EPIPE, "broken pipe"},
+		{"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"},
+		{"unrecognized error fallback", errors.New("weird internal error"), "upstream connection error"},
+		{"nil returns empty string", nil, ""},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			require.Equal(t, tc.want, sanitizeStreamError(tc.err))
+		})
+	}
+}
+
+// The failover ResponseBody must use the sanitized message so internal address information is neither leaked
+// to clients nor written into ops logs.
+func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	svc := newMinimalGatewayService()
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+	src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
+	dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
+	netErr := &net.OpError{
+		Op:     "read",
+		Net:    "tcp",
+		Source: src,
+		Addr:   dst,
+		Err:    syscall.ECONNRESET,
+	}
+
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+		Body:       &streamReadCloser{err: netErr},
+	}
+
+	_, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+	require.Error(t, err)
+
+	var failoverErr *UpstreamFailoverError
+	require.True(t, errors.As(err, &failoverErr))
+
+	body := string(failoverErr.ResponseBody)
+	require.NotContains(t, body, "10.0.0.1", "failover ResponseBody must not leak the internal source IP")
+	require.NotContains(t, body, "54321")
+	require.NotContains(t, body, "52.1.2.3", "failover ResponseBody must not leak the upstream IP")
+	require.NotContains(t, body, "443")
+	// Still contains a diagnosable root cause.
+	require.Contains(t, body, "connection reset by peer")
+	require.Contains(t, body, "upstream stream disconnected")
+}
@@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
 		}
 		// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
 		return 3
+	case AccountTypeServiceAccount:
+		// Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
+		// endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
+		return 999
 	default:
 		return 10
 	}
@@ -579,7 +583,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
 
 	originalModel := req.Model
 	mappedModel := req.Model
-	if account.Type == AccountTypeAPIKey {
+	if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
 		mappedModel = account.GetMappedModel(req.Model)
 	}
 
@@ -712,6 +716,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
 		}
 		requestIDHeader = "x-request-id"
 
+	case AccountTypeServiceAccount:
+		buildReq = func(ctx context.Context) (*http.Request, string, error) {
+			if s.tokenProvider == nil {
+				return nil, "", errors.New("gemini token provider not configured")
+			}
+			accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
+			if err != nil {
+				return nil, "", err
+			}
+
+			action := "generateContent"
+			if req.Stream {
+				action = "streamGenerateContent"
+			}
+			fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream)
+			if err != nil {
+				return nil, "", err
+			}
+
+			restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
+			upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
+			if err != nil {
+				return nil, "", err
+			}
+			upstreamReq.Header.Set("Content-Type", "application/json")
+			upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+			return upstreamReq, "x-request-id", nil
+		}
+		requestIDHeader = "x-request-id"
+
 	default:
 		return nil, fmt.Errorf("unsupported account type: %s", account.Type)
 	}
@@ -1094,7 +1128,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
 	body = ensureGeminiFunctionCallThoughtSignatures(body)
 
 	mappedModel := originalModel
-	if account.Type == AccountTypeAPIKey {
+	if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
 		mappedModel = account.GetMappedModel(originalModel)
 	}
 
@@ -1213,6 +1247,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
 		}
 		requestIDHeader = "x-request-id"
 
+	case AccountTypeServiceAccount:
+		buildReq = func(ctx context.Context) (*http.Request, string, error) {
+			if s.tokenProvider == nil {
+				return nil, "", errors.New("gemini token provider not configured")
+			}
+			accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
+			if err != nil {
+				return nil, "", err
+			}
+
+			fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream)
+			if err != nil {
+				return nil, "", err
+			}
+
+			upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
+			if err != nil {
+				return nil, "", err
+			}
+			upstreamReq.Header.Set("Content-Type", "application/json")
+			upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+			return upstreamReq, "x-request-id", nil
+		}
+		requestIDHeader = "x-request-id"
+
 	default:
 		return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
 	}
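The diff calls `buildVertexGeminiURL` but its body is not part of this change, so the exact URL shape is an assumption. The sketch below (hypothetical helper name `buildVertexURL`) follows the public Vertex AI endpoint convention of `{location}-aiplatform.googleapis.com` with a `publishers/google/models/{model}:{action}` path:

```go
package main

import (
	"fmt"
	"net/url"
)

// buildVertexURL is a hypothetical sketch of a helper like the diff's
// buildVertexGeminiURL; the host/path shape is an assumption based on the
// public Vertex AI REST convention, not this project's implementation.
func buildVertexURL(projectID, location, model, action string, stream bool) (string, error) {
	if projectID == "" || location == "" || model == "" {
		return "", fmt.Errorf("missing projectID/location/model")
	}
	u := url.URL{
		Scheme: "https",
		Host:   location + "-aiplatform.googleapis.com",
		Path: fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
			projectID, location, model, action),
	}
	if stream {
		// Streaming responses are typically requested as SSE.
		u.RawQuery = "alt=sse"
	}
	return u.String(), nil
}

func main() {
	s, _ := buildVertexURL("my-project", "us-central1", "gemini-2.0-flash", "streamGenerateContent", true)
	fmt.Println(s)
}
```

This also shows why `AccountTypeServiceAccount` scores 999 in `SelectAccountForAIStudioEndpoints`: the host is `aiplatform.googleapis.com`, not the AI Studio `generativelanguage.googleapis.com` endpoint.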
@@ -15,7 +15,7 @@ const (
 	geminiTokenCacheSkew = 5 * time.Minute
 )
 
-// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
+// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
 type GeminiTokenProvider struct {
 	accountRepo AccountRepository
 	tokenCache  GeminiTokenCache
@@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
 	if account == nil {
 		return "", errors.New("account is nil")
 	}
-	if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
-		return "", errors.New("not a gemini oauth account")
+	if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
+		return "", errors.New("not a gemini oauth or service account")
+	}
+	if account.Type == AccountTypeServiceAccount {
+		return p.getServiceAccountAccessToken(ctx, account)
 	}
 
 	cacheKey := GeminiTokenCacheKey(account)
@@ -168,7 +171,16 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
 	return accessToken, nil
 }
 
+func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
+	return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
+}
+
 func GeminiTokenCacheKey(account *Account) string {
+	if account != nil && account.Type == AccountTypeServiceAccount {
+		if key, err := parseVertexServiceAccountKey(account); err == nil {
+			return vertexServiceAccountCacheKey(account, key)
+		}
+	}
 	projectID := strings.TrimSpace(account.GetCredential("project_id"))
 	if projectID != "" {
 		return "gemini:" + projectID
@@ -53,6 +53,23 @@ const (
 	codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
 )
+
+var openAIChatGPTInternalUnsupportedFields = []string{
+	"user",
+	"metadata",
+	"prompt_cache_retention",
+	"safety_identifier",
+	"stream_options",
+}
+
+var openAICodexOAuthUnsupportedFields = append([]string{
+	"max_output_tokens",
+	"max_completion_tokens",
+	"temperature",
+	"top_p",
+	"frequency_penalty",
+	"presence_penalty",
+}, openAIChatGPTInternalUnsupportedFields...)
 
 func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
 	result := codexTransformResult{}
 	// Tool-chain continuation requirements affect the storage policy and input filtering logic.
@@ -93,23 +110,8 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
 		}
 	}
 
-	// Strip parameters unsupported by codex models via the Responses API.
-	for _, key := range []string{
-		"max_output_tokens",
-		"max_completion_tokens",
-		"temperature",
-		"top_p",
-		"frequency_penalty",
-		"presence_penalty",
-		// prompt_cache_retention is a newer Responses API parameter (cache TTL).
-		// The ChatGPT internal Codex endpoint rejects it with
-		// "Unsupported parameter: prompt_cache_retention". Defense-in-depth
-		// for any OAuth path that reaches this transform — the Cursor
-		// Responses-shape short-circuit in ForwardAsChatCompletions strips
-		// it earlier too, but we keep this line so other OAuth callers are
-		// equally protected.
-		"prompt_cache_retention",
-	} {
+	// Strip parameters unsupported by the ChatGPT internal Codex endpoint.
+	for _, key := range openAICodexOAuthUnsupportedFields {
 		if _, ok := reqBody[key]; ok {
 			delete(reqBody, key)
 			result.Modified = true
@@ -141,9 +143,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
 			if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
 				reqBody["tool_choice"] = map[string]any{
 					"type": "function",
-					"function": map[string]any{
-						"name": name,
-					},
+					"name": name,
 				}
 			}
 		}
@@ -219,9 +219,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool {
 		return false
 	}
 	choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
-	if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
+	if choiceType == "" {
 		return false
 	}
+	modified := false
+	if choiceType == "function" {
+		name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"]))
+		if name == "" {
+			if function, ok := choiceMap["function"].(map[string]any); ok {
+				name = strings.TrimSpace(firstNonEmptyString(function["name"]))
+			}
+		}
+		if name == "" {
+			reqBody["tool_choice"] = "auto"
+			return true
+		}
+		if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name {
+			choiceMap["name"] = name
+			modified = true
+		}
+		if _, ok := choiceMap["function"]; ok {
+			delete(choiceMap, "function")
+			modified = true
+		}
+		if !codexToolsContainFunctionName(reqBody["tools"], name) {
+			reqBody["tool_choice"] = "auto"
+			return true
+		}
+		return modified
+	}
+	if codexToolsContainType(reqBody["tools"], choiceType) {
+		return modified
+	}
 	reqBody["tool_choice"] = "auto"
 	return true
 }
@@ -243,6 +272,33 @@ func codexToolsContainType(rawTools any, toolType string) bool {
 	return false
 }
 
+func codexToolsContainFunctionName(rawTools any, name string) bool {
+	tools, ok := rawTools.([]any)
+	if !ok || strings.TrimSpace(name) == "" {
+		return false
+	}
+	normalizedName := strings.TrimSpace(name)
+	for _, rawTool := range tools {
+		tool, ok := rawTool.(map[string]any)
+		if !ok {
+			continue
+		}
+		if strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" {
+			continue
+		}
+		toolName := strings.TrimSpace(firstNonEmptyString(tool["name"]))
+		if toolName == "" {
+			if function, ok := tool["function"].(map[string]any); ok {
+				toolName = strings.TrimSpace(firstNonEmptyString(function["name"]))
+			}
+		}
+		if toolName == normalizedName {
+			return true
+		}
+	}
+	return false
+}
+
 func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
 	if len(input) == 0 {
 		return input, false
@@ -853,6 +909,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
 		}
 		typ, _ := m["type"].(string)
 
+		// chatgpt.com codex backend (OAuth path) does not persist reasoning
+		// items because applyCodexOAuthTransform forces store=false. Any rs_*
+		// reference replayed in input is guaranteed to 404 upstream
+		// ("Item with id 'rs_...' not found"). Drop reasoning items entirely.
+		if typ == "reasoning" {
+			continue
+		}
+
 		// Only fix genuine tool/function call IDs, to avoid mangling ordinary message/reasoning IDs;
 		// if an item_reference points at a legacy call_* ID, fix only that reference itself.
 		fixCallIDPrefix := func(id string) string {
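The `tool_choice` normalization above can be illustrated standalone. A minimal sketch (hypothetical helper, not the project's `normalizeCodexToolChoice`) of flattening the legacy Chat Completions shape `{"type":"function","function":{"name":...}}` into the Responses-style `{"type":"function","name":...}`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// flattenFunctionToolChoice sketches the core rewrite: promote the nested
// function.name to a top-level name and drop the legacy "function" object.
func flattenFunctionToolChoice(choice map[string]any) {
	if choice["type"] != "function" {
		return
	}
	if _, ok := choice["name"]; !ok {
		if fn, ok := choice["function"].(map[string]any); ok {
			if name, ok := fn["name"].(string); ok {
				choice["name"] = name
			}
		}
	}
	delete(choice, "function")
}

func main() {
	choice := map[string]any{
		"type":     "function",
		"function": map[string]any{"name": "shell"},
	}
	flattenFunctionToolChoice(choice)
	out, _ := json.Marshal(choice)
	fmt.Println(string(out)) // {"name":"shell","type":"function"}
}
```

The real function additionally downgrades to `"auto"` when the named function is absent from `tools` (via `codexToolsContainFunctionName`), so a stale forced choice cannot produce an upstream validation error.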
@@ -1,6 +1,8 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -249,6 +251,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
|
|||||||
require.Equal(t, "custom", choice["type"])
|
require.Equal(t, "custom", choice["type"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) {
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.4",
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{"type": "function", "name": "shell"},
|
||||||
|
},
|
||||||
|
"tool_choice": map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{"name": "shell"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody, true, false)
|
||||||
|
|
||||||
|
choice, ok := reqBody["tool_choice"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "function", choice["type"])
|
||||||
|
require.Equal(t, "shell", choice["name"])
|
||||||
|
require.NotContains(t, choice, "function")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) {
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.4",
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{"type": "function", "name": "shell"},
|
||||||
|
},
|
||||||
|
"tool_choice": map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{"name": "missing"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody, true, false)
|
||||||
|
|
||||||
|
require.Equal(t, "auto", reqBody["tool_choice"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
|
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
|
||||||
reqBody := map[string]any{
|
reqBody := map[string]any{
|
||||||
"model": "gpt-5.4",
|
"model": "gpt-5.4",
|
||||||
@@ -1048,6 +1088,27 @@ func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) {
		"prompt_cache_retention must be stripped before forwarding to Codex upstream")
}

func TestApplyCodexOAuthTransform_StripsChatGPTInternalUnsupportedFields(t *testing.T) {
	reqBody := map[string]any{
		"model":                  "gpt-5.4",
		"user":                   "user_123",
		"metadata":               map[string]any{"trace_id": "abc"},
		"prompt_cache_retention": "24h",
		"safety_identifier":      "sid",
		"stream_options":         map[string]any{"include_usage": true},
		"input": []any{
			map[string]any{"role": "user", "content": "hi"},
		},
	}

	result := applyCodexOAuthTransform(reqBody, true, false)

	require.True(t, result.Modified)
	for _, field := range openAIChatGPTInternalUnsupportedFields {
		require.NotContains(t, reqBody, field)
	}
}

func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) {
	reqBody := map[string]any{
		"model": "gpt-5.1",
@@ -1094,3 +1155,56 @@ func TestIsInstructionsEmpty(t *testing.T) {
		})
	}
}

func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) {
	// Reasoning items in input[] reference rs_* IDs that were emitted by
	// chatgpt.com under store=false (forced by applyCodexOAuthTransform).
	// They are never persisted upstream, so forwarding them produces a
	// guaranteed 404 ("Item with id 'rs_...' not found"). Drop them
	// regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957.

	build := func() []any {
		return []any{
			map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
			map[string]any{
				"type":    "reasoning",
				"id":      "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344",
				"summary": []any{},
			},
			map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"},
			map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"},
		}
	}

	for _, preserve := range []bool{true, false} {
		preserve := preserve
		t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) {
			filtered := filterCodexInput(build(), preserve)

			for _, raw := range filtered {
				item, ok := raw.(map[string]any)
				require.True(t, ok)
				require.NotEqual(t, "reasoning", item["type"],
					"reasoning items must be dropped from input on the OAuth path")
				if id, ok := item["id"].(string); ok {
					require.False(t, strings.HasPrefix(id, "rs_"),
						"no item carrying an rs_* id should survive the filter")
				}
			}

			// Sanity check: the non-reasoning items should still be present.
			gotTypes := make(map[string]int)
			for _, raw := range filtered {
				item, ok := raw.(map[string]any)
				require.True(t, ok)
				typ, ok := item["type"].(string)
				require.True(t, ok)
				gotTypes[typ]++
			}
			require.Equal(t, 1, gotTypes["message"])
			require.Equal(t, 1, gotTypes["function_call"])
			require.Equal(t, 1, gotTypes["function_call_output"])
			require.Equal(t, 0, gotTypes["reasoning"])
		})
	}
}
286	backend/internal/service/openai_fast_policy_test.go	Normal file
@@ -0,0 +1,286 @@
package service

import (
	"context"
	"encoding/json"
	"errors"
	"testing"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/stretchr/testify/require"
)

type openAIFastPolicyRepoStub struct {
	values map[string]string
}

func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
	panic("unexpected Get call")
}

func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
	if v, ok := s.values[key]; ok {
		return v, nil
	}
	return "", ErrSettingNotFound
}

func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
	if s.values == nil {
		s.values = map[string]string{}
	}
	s.values[key] = value
	return nil
}

func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
	panic("unexpected GetMultiple call")
}

func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
	panic("unexpected SetMultiple call")
}

func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
	panic("unexpected GetAll call")
}

func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
	panic("unexpected Delete call")
}

func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
	t.Helper()
	repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
	if settings != nil {
		raw, err := json.Marshal(settings)
		require.NoError(t, err)
		repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
	}
	return &OpenAIGatewayService{
		settingService: NewSettingService(repo, &config.Config{}),
	}
}

func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The default policy applies to all models (empty whitelist), because
	// codex's service_tier=fast is a user-level switch, orthogonal to model.
	// gpt-5.5 + priority → filter
	action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-5.5-turbo → filter
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-4 + priority → filter (the default policy covers all models)
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-5.5 + flex → pass (tier doesn't match)
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
	require.Equal(t, BetaPolicyActionPass, action)

	// empty tier → pass
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
	require.Equal(t, BetaPolicyActionPass, action)
}

func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionBlock,
			Scope:          BetaPolicyScopeAll,
			ErrorMessage:   "fast mode is not allowed",
			ModelWhitelist: []string{"gpt-5.5"},
			FallbackAction: BetaPolicyActionPass,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionBlock, action)
	require.Equal(t, "fast mode is not allowed", msg)
}

func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierAny,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeOAuth,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)

	// OAuth account → rule matches
	oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
	action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// API Key account → rule skipped → pass
	apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionPass, action)
}

func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// gpt-5.5 fast → service_tier stripped
	body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// Client sending "fast" (alias for priority) also filtered
	body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// gpt-4 priority → the default policy filters for all models; service_tier is removed
	body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// No service_tier → no-op
	body = []byte(`{"model":"gpt-5.5"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))
}

// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule verifies that,
// after widening the whitelist, the OpenAI official legal tiers (auto/default/
// scale) explicitly sent by a client are passed through upstream rather than
// silently stripped. The default policy only targets priority, so these tiers
// land in the fall-through pass branch.
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	for _, tier := range []string{"auto", "default", "scale"} {
		body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
		updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
		require.NoError(t, err, "tier %q should pass without error", tier)
		require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
			"tier %q should be preserved in body under default rule", tier)
	}

	// The evaluate layer should also judge these as pass (the default rule's
	// ServiceTier=priority does not match auto/default/scale).
	for _, tier := range []string{"auto", "default", "scale"} {
		action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
		require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
	}
}

// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers verifies that once
// an admin explicitly configures a ServiceTier=all + Action=filter rule, the
// official tiers auto/default/scale are stripped as well. This is by design:
// the first matching rule short-circuits, and "all" covers any recognized tier.
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierAny,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeAll,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
		body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
		updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
		require.NoError(t, err)
		require.NotContains(t, string(updated), `"service_tier"`,
			"tier %q should be stripped under ServiceTier=all + filter rule", tier)
	}
}

// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped verifies that truly
// unknown tiers are still stripped (normalize returns nil →
// normalizeResponsesBodyServiceTier deletes the field;
// applyOpenAIFastPolicyToBody is a direct no-op when normTier is empty, since
// the field cannot survive a request that went through the upfront
// normalization. Here we call apply directly to verify it does not misbehave
// on unrecognized values).
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The normalize stage strips unknown values.
	require.Nil(t, normalizeOpenAIServiceTier("xxx"))

	// applyOpenAIFastPolicyToBody does not error on an unrecognized tier; the
	// body passes through unchanged (not this function's responsibility — the
	// upstream normalizeResponsesBodyServiceTier has already stripped it).
	body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))
}

func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionBlock,
			Scope:          BetaPolicyScopeAll,
			ErrorMessage:   "fast mode is blocked for gpt-5.5",
			ModelWhitelist: []string{"gpt-5.5"},
			FallbackAction: BetaPolicyActionPass,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.Error(t, err)
	var blocked *OpenAIFastBlockedError
	require.True(t, errors.As(err, &blocked))
	require.Contains(t, blocked.Message, "fast mode is blocked")
	require.Equal(t, string(body), string(updated)) // body not mutated on block
}

func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
	repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
	svc := NewSettingService(repo, &config.Config{})

	// Invalid action rejected
	err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierPriority,
			Action:      "bogus",
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.Error(t, err)

	// Invalid service_tier rejected
	err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: "turbo",
			Action:      BetaPolicyActionPass,
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.Error(t, err)

	// Valid settings persisted
	err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierPriority,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.NoError(t, err)

	got, err := svc.GetOpenAIFastPolicySettings(context.Background())
	require.NoError(t, err)
	require.Len(t, got.Rules, 1)
	require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
}
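The suite above implies a first-match rule walk: tier mismatch skips a rule, a whitelist miss triggers the fallback action, and the first matching rule short-circuits. A minimal standalone sketch of that evaluation order (type and field names here are illustrative; the real `evaluateOpenAIFastPolicy` also consults account scope and more):

```go
package main

import "fmt"

// rule is a stripped-down stand-in for the policy rules exercised above.
type rule struct {
	ServiceTier    string   // "priority", "flex", or "all"
	Action         string   // "pass", "filter", "block"
	ModelWhitelist []string // empty = applies to every model
	FallbackAction string   // used when the whitelist does not match
}

func evaluate(rules []rule, model, tier string) string {
	if tier == "" {
		return "pass" // no service_tier requested: nothing to police
	}
	for _, r := range rules {
		if r.ServiceTier != "all" && r.ServiceTier != tier {
			continue // tier mismatch: try the next rule
		}
		if len(r.ModelWhitelist) > 0 {
			hit := false
			for _, m := range r.ModelWhitelist {
				if m == model {
					hit = true
					break
				}
			}
			if !hit {
				if r.FallbackAction != "" {
					return r.FallbackAction
				}
				continue
			}
		}
		return r.Action // first matching rule short-circuits
	}
	return "pass"
}

func main() {
	rules := []rule{{ServiceTier: "priority", Action: "filter"}}
	fmt.Println(evaluate(rules, "gpt-5.5", "priority")) // filter
	fmt.Println(evaluate(rules, "gpt-5.5", "flex"))     // pass
}
```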
1018	backend/internal/service/openai_fast_policy_ws_test.go	Normal file
File diff suppressed because it is too large
@@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
		}
	}

	// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
	updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
	if policyErr != nil {
		var blocked *OpenAIFastBlockedError
		if errors.As(policyErr, &blocked) {
			writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
		}
		return nil, policyErr
	}
	responsesBody = updatedBody

	// 5. Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
@@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
	normalizeResponsesRequestServiceTier(req)
	require.Equal(t, "flex", req.ServiceTier)

	// OpenAI's official legal tiers should be passed through unchanged.
	req.ServiceTier = "auto"
	normalizeResponsesRequestServiceTier(req)
	require.Equal(t, "auto", req.ServiceTier)

	req.ServiceTier = "default"
	normalizeResponsesRequestServiceTier(req)
	require.Equal(t, "default", req.ServiceTier)

	req.ServiceTier = "scale"
	normalizeResponsesRequestServiceTier(req)
	require.Equal(t, "scale", req.ServiceTier)

	// Truly unknown values are still stripped.
	req.ServiceTier = "turbo"
	normalizeResponsesRequestServiceTier(req)
	require.Empty(t, req.ServiceTier)
}

@@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
	require.Equal(t, "flex", tier)
	require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())

	// OpenAI official tiers are kept in the body as-is (passed through upstream).
	body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
	require.NoError(t, err)
	require.Equal(t, "auto", tier)
	require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())

	body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
	require.NoError(t, err)
	require.Equal(t, "default", tier)
	require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())

	body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
	require.NoError(t, err)
	require.Equal(t, "scale", tier)
	require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())

	// Only truly unknown values get deleted.
	body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
	require.NoError(t, err)
	require.Empty(t, tier)
	require.False(t, gjson.GetBytes(body, "service_tier").Exists())
}
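The normalization behavior asserted in the two tests above — "fast" aliased to "priority", the official tiers preserved, everything else dropped — can be sketched as a small whitelist lookup (a standalone illustration of the asserted contract, not the repository's `normalizeOpenAIServiceTier`):

```go
package main

import "fmt"

// normalizeServiceTier mirrors the contract the tests above pin down:
// "fast" is an alias for "priority"; the official tiers pass through;
// anything else maps to nil so the caller can strip the field.
func normalizeServiceTier(raw string) *string {
	if raw == "fast" {
		raw = "priority"
	}
	switch raw {
	case "priority", "flex", "auto", "default", "scale":
		return &raw
	}
	return nil // unknown tier: caller deletes service_tier from the body
}

func main() {
	if got := normalizeServiceTier("fast"); got != nil {
		fmt.Println(*got) // priority
	}
	fmt.Println(normalizeServiceTier("turbo") == nil) // true
}
```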
@@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
		}
	}

	// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
	// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
	// on the body-level service_tier field (priority/flex).
	updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
	if policyErr != nil {
		var blocked *OpenAIFastBlockedError
		if errors.As(policyErr, &blocked) {
			writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
		}
		return nil, policyErr
	}
	responsesBody = updatedBody

	// 5. Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
		nil,
		nil,
		nil,
		nil,
	)
	svc.userGroupRateResolver = newUserGroupRateResolver(
		rateRepo,
@@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
		require.Equal(t, "priority", *got)
	})

	t.Run("openai official tiers preserved", func(t *testing.T) {
		// Every legal tier value defined in OpenAI's official docs should be
		// passed through, so a too-narrow whitelist cannot silently strip a
		// legitimate field the client sent explicitly. The Codex client only
		// sends priority/flex, so widening the whitelist has zero impact on
		// Codex traffic (see codex-rs/core/src/client.rs).
		for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
			got := normalizeOpenAIServiceTier(tier)
			require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
			require.Equal(t, tier, *got)
		}
	})

	t.Run("invalid ignored", func(t *testing.T) {
		require.Nil(t, normalizeOpenAIServiceTier("turbo"))
		require.Nil(t, normalizeOpenAIServiceTier("xxx"))
	})
}

func TestExtractOpenAIServiceTier(t *testing.T) {
	require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
	require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
	require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
	require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
	require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
	require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
	require.Nil(t, extractOpenAIServiceTier(nil))
}
@@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
	require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
	require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
	require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
	require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
	require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
	require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
	require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}
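The filter action asserted throughout these tests boils down to deleting `service_tier` from the raw JSON body and leaving everything else untouched. A self-contained sketch using `encoding/json` (the repository appears to use gjson-style editing; this stand-in function is only an illustration):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// stripServiceTier decodes the body, drops service_tier if present, and
// re-encodes. A body without the field passes through byte-for-byte, matching
// the "no service_tier → no-op" assertion above.
func stripServiceTier(body []byte) ([]byte, error) {
	var m map[string]any
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, err
	}
	if _, ok := m["service_tier"]; !ok {
		return body, nil // no-op when the field is absent
	}
	delete(m, "service_tier")
	return json.Marshal(m)
}

func main() {
	out, err := stripServiceTier([]byte(`{"model":"gpt-5.5","service_tier":"priority"}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // {"model":"gpt-5.5"}
}
```

Note that re-marshaling through a map loses the original key order and formatting, which is one reason a production gateway would prefer in-place JSON editing over decode/re-encode.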
@@ -334,6 +334,7 @@ type OpenAIGatewayService struct {
	resolver             *ModelPricingResolver
	channelService       *ChannelService
	balanceNotifyService *BalanceNotifyService
	settingService       *SettingService

	openaiWSPoolOnce       sync.Once
	openaiWSStateStoreOnce sync.Once
@@ -372,6 +373,7 @@ func NewOpenAIGatewayService(
	resolver *ModelPricingResolver,
	channelService *ChannelService,
	balanceNotifyService *BalanceNotifyService,
	settingService *SettingService,
) *OpenAIGatewayService {
	svc := &OpenAIGatewayService{
		accountRepo: accountRepo,
@@ -402,6 +404,7 @@ func NewOpenAIGatewayService(
		resolver:              resolver,
		channelService:        channelService,
		balanceNotifyService:  balanceNotifyService,
		settingService:        settingService,
		responseHeaderFilter:  compileResponseHeaderFilter(cfg),
		codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
	}
@@ -1125,6 +1128,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
	return sessionID
}

func explicitOpenAISessionID(c *gin.Context, body []byte) string {
	if c == nil {
		return ""
	}

	sessionID := strings.TrimSpace(c.GetHeader("session_id"))
	if sessionID == "" {
		sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
	}
	if sessionID == "" && len(body) > 0 {
		sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
	}
	return sessionID
}

// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
// client session signals. It intentionally skips content-derived fallback and is
// used by stateless endpoints such as /v1/images.
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
	sessionID := explicitOpenAISessionID(c, body)
	if sessionID == "" {
		return ""
	}

	currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
	attachOpenAILegacySessionHashToGin(c, legacyHash)
	return currentHash
}

// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
@@ -1137,13 +1169,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
		return ""
	}

	sessionID := explicitOpenAISessionID(c, body)
	if sessionID == "" && len(body) > 0 {
		sessionID = deriveOpenAIContentSessionSeed(body)
	}
@@ -2287,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
		disablePatch()
	}

	// Apply OpenAI fast policy (modeled on the Claude BetaPolicy fast-mode
	// filter): act on the body's service_tier field ("priority", i.e. fast,
	// and "flex") and, per policy, either filter (delete the field) or block
	// (reject the request). Blocking fast for models such as gpt-5.5 takes
	// effect here.
	//
	// Notes:
	// 1. upstreamModel is used consistently here (already passed through
	//    GetMappedModel + normalizeOpenAIModelForUpstream + Codex OAuth
	//    normalize), matching the chat-completions / messages entry points so
	//    the whitelist cannot match differently across entry points that see
	//    different model dimensions.
	// 2. Even when action=pass, a raw "fast" must be normalized to "priority"
	//    and written back to the body, otherwise the native /responses entry
	//    point would pass "fast" through and be rejected upstream. The chat-
	//    completions entry point gets the same behavior from
	//    normalizeResponsesBodyServiceTier; here the equivalent logic is done
	//    by hand.
	if rawTier, ok := reqBody["service_tier"].(string); ok {
		if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
			action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
			switch action {
			case BetaPolicyActionBlock:
				msg := errMsg
				if msg == "" {
					msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
				}
				blocked := &OpenAIFastBlockedError{Message: msg}
				writeOpenAIFastPolicyBlockedResponse(c, blocked)
				return nil, blocked
			case BetaPolicyActionFilter:
				delete(reqBody, "service_tier")
				bodyModified = true
				disablePatch()
			default:
				// pass: if the client sent the alias "fast", normalize it to
				// "priority" and write it back so upstream receives a value
				// it recognizes.
				if normTier != rawTier {
					reqBody["service_tier"] = normTier
					bodyModified = true
					markPatchSet("service_tier", normTier)
				}
			}
		}
	}

	// Re-serialize body only if modified
	if bodyModified {
		serializedByPatch := false
@@ -2735,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
         body = sanitizedBody
     }
 
+    // Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
+    // Use the model as the upstream sees it: on the passthrough path the body has
+    // already been through compact mapping + OAuth normalize, so its model field is
+    // the slug the upstream actually receives, consistent with the upstreamModel of
+    // the chat-completions / messages / native /responses entry points (no
+    // whitelist-hit differences). Fall back to reqModel when the body has no model.
+    policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
+    if policyModel == "" {
+        policyModel = reqModel
+    }
+    updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
+    if policyErr != nil {
+        var blocked *OpenAIFastBlockedError
+        if errors.As(policyErr, &blocked) {
+            writeOpenAIFastPolicyBlockedResponse(c, blocked)
+        }
+        return nil, policyErr
+    }
+    body = updatedBody
+
     logger.LegacyPrintf("service.openai_gateway",
         "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
         account.ID,
@@ -4841,7 +4929,18 @@ func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) {
     }
 
     normalized := []byte(`{}`)
-    for _, field := range []string{"model", "input", "instructions", "previous_response_id"} {
+    // Keep the current Codex /compact schema while still dropping request-scoped
+    // fields such as prompt_cache_key, store, and stream.
+    for _, field := range []string{
+        "model",
+        "input",
+        "instructions",
+        "tools",
+        "parallel_tool_calls",
+        "reasoning",
+        "text",
+        "previous_response_id",
+    } {
         value := gjson.GetBytes(body, field)
         if !value.Exists() {
             continue
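The widened loop rebuilds the compact body by copying only whitelisted top-level fields into a fresh object, which implicitly drops everything else. The same copy-by-whitelist idea, sketched with only the standard library (the repository uses gjson/sjson; `copyWhitelisted` is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// copyWhitelisted keeps only the named top-level fields of a JSON object,
// mirroring how a field whitelist drops request-scoped keys (store, stream,
// prompt_cache_key) without enumerating them.
func copyWhitelisted(body []byte, fields []string) ([]byte, error) {
	var src map[string]json.RawMessage
	if err := json.Unmarshal(body, &src); err != nil {
		return nil, err
	}
	dst := map[string]json.RawMessage{}
	for _, f := range fields {
		if v, ok := src[f]; ok {
			dst[f] = v // copy raw bytes; absent fields are simply skipped
		}
	}
	return json.Marshal(dst)
}

func main() {
	body := []byte(`{"model":"gpt-5.5","stream":true,"store":true,"tools":[]}`)
	out, err := copyWhitelisted(body, []string{"model", "tools"})
	fmt.Println(string(out), err) // stream and store are gone
}
```

A whitelist copy is safer here than a blacklist delete: new request-scoped fields added by clients are dropped by default instead of leaking into the compact call.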
@@ -5454,7 +5553,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
 }
 
 // normalizeOpenAIPassthroughOAuthBody converges the passthrough OAuth request body to the legacy path's key behaviors:
-// 1) store=false 2) non-compact keeps stream=true; compact forces stream=false
+// 1) drop top-level Responses parameters the ChatGPT internal API does not support
+// 2) store=false 3) non-compact keeps stream=true; compact forces stream=false
 func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) {
     if len(body) == 0 {
         return body, false, nil
@@ -5463,6 +5563,18 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte, boo
     normalized := body
     changed := false
 
+    for _, field := range openAIChatGPTInternalUnsupportedFields {
+        if value := gjson.GetBytes(normalized, field); !value.Exists() {
+            continue
+        }
+        next, err := sjson.DeleteBytes(normalized, field)
+        if err != nil {
+            return body, false, fmt.Errorf("normalize passthrough body delete %s: %w", field, err)
+        }
+        normalized = next
+        changed = true
+    }
+
     if compact {
         if store := gjson.GetBytes(normalized, "store"); store.Exists() {
             next, err := sjson.DeleteBytes(normalized, "store")
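The inverse operation — deleting a blacklist of top-level fields and reporting whether anything changed — can be sketched the same way with the standard library (the real code iterates openAIChatGPTInternalUnsupportedFields, whose exact contents this hunk does not show, and uses sjson.DeleteBytes; `stripFields` is an illustrative stand-in):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// stripFields removes the named top-level keys and reports whether the body
// changed, mirroring the delete loop in normalizeOpenAIPassthroughOAuthBody.
// On parse failure it returns the original body untouched.
func stripFields(body []byte, fields []string) ([]byte, bool, error) {
	var obj map[string]json.RawMessage
	if err := json.Unmarshal(body, &obj); err != nil {
		return body, false, err
	}
	changed := false
	for _, f := range fields {
		if _, ok := obj[f]; ok {
			delete(obj, f)
			changed = true
		}
	}
	if !changed {
		return body, false, nil // avoid re-serializing an unchanged body
	}
	out, err := json.Marshal(obj)
	return out, true, err
}

func main() {
	body := []byte(`{"model":"gpt-5.4","user":"user_123"}`)
	out, changed, _ := stripFields(body, []string{"user", "safety_identifier"})
	fmt.Println(string(out), changed)
}
```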
@@ -5567,14 +5679,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
     if value == "fast" {
         value = "priority"
     }
+    // Let through every legal tier from OpenAI's docs: priority/flex/auto/default/scale.
+    // Zero Codex impact (Codex only sends priority or flex; see codex-rs/core/src/client.rs),
+    // but users on the OpenAI SDK directly can pass auto/default/scale through for capture/debugging.
+    // Truly unknown values still return nil; normalizeResponsesBodyServiceTier then removes them from the body.
     switch value {
-    case "priority", "flex":
+    case "priority", "flex", "auto", "default", "scale":
         return &value
     default:
         return nil
     }
 }
 
+// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
+// policy (action=block). Mirrors BetaBlockedError on the Claude side.
+type OpenAIFastBlockedError struct {
+    Message string
+}
+
+func (e *OpenAIFastBlockedError) Error() string { return e.Message }
+
+// evaluateOpenAIFastPolicy returns the action and error message that should be
+// applied for a request with the given account/model/service_tier. When the
+// policy service is unavailable or no rule matches, it returns
+// (BetaPolicyActionPass, "") so callers can short-circuit safely.
+//
+// Matching rules:
+//   - Scope filters by account type (all / oauth / apikey / bedrock)
+//   - ServiceTier must be empty (= any), "all", or equal the normalized tier
+//   - ModelWhitelist narrows the rule to specific models; FallbackAction
+//     handles the non-matching case (default: pass)
+//
+// Differences from the Claude BetaPolicy (first-match short-circuit kept):
+//   - BetaPolicy operates on the set of tokens in the anthropic-beta header;
+//     different rules may target different tokens, so filter results must
+//     accumulate into a set, while block is first-match.
+//   - The OpenAI fast policy operates on a single field, service_tier: filter
+//     just deletes the field, so there is nothing to accumulate. A request
+//     carries exactly one service_tier, so the tier dimension of the rules is
+//     naturally mutually exclusive; if several rules under the same
+//     (scope, tier) have overlapping model whitelists, the admin can make the
+//     intent explicit through rule order. Hence first-match rather than
+//     BetaPolicy's "block overrides filter overrides pass" semantics.
+func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
+    if s == nil || s.settingService == nil {
+        return BetaPolicyActionPass, ""
+    }
+    tier := strings.ToLower(strings.TrimSpace(serviceTier))
+    if tier == "" {
+        return BetaPolicyActionPass, ""
+    }
+    settings := openAIFastPolicySettingsFromContext(ctx)
+    if settings == nil {
+        fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
+        if err != nil || fetched == nil {
+            return BetaPolicyActionPass, ""
+        }
+        settings = fetched
+    }
+    return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
+}
+
+// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so
+// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting
+// the settingService on every frame. See WSSession entry and
+// openAIFastPolicySettingsFromContext for the caching glue.
+func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
+    if settings == nil {
+        return BetaPolicyActionPass, ""
+    }
+    isOAuth := account != nil && account.IsOAuth()
+    isBedrock := account != nil && account.IsBedrock()
+    for _, rule := range settings.Rules {
+        if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
+            continue
+        }
+        ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
+        if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
+            continue
+        }
+        eff := BetaPolicyRule{
+            Action:               rule.Action,
+            ErrorMessage:         rule.ErrorMessage,
+            ModelWhitelist:       rule.ModelWhitelist,
+            FallbackAction:       rule.FallbackAction,
+            FallbackErrorMessage: rule.FallbackErrorMessage,
+        }
+        return resolveRuleAction(eff, model)
+    }
+    return BetaPolicyActionPass, ""
+}
+
+// openAIFastPolicyCtxKey is the context key for a prefetched
+// OpenAIFastPolicySettings snapshot, used only so a long-lived WebSocket
+// session can reuse one policy snapshot across its frames instead of hitting
+// the DB on every frame.
+//
+// Trade-off: policy changes do not affect an in-flight WS session (only new
+// sessions). This is deliberate — for long sessions, policy consistency beats
+// immediate effect, and the Claude BetaPolicy's gin.Context cache makes the
+// same trade-off. When a hot reload is needed, an admin can force a refresh
+// by kicking the session.
+type openAIFastPolicyCtxKeyType struct{}
+
+var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
+
+// withOpenAIFastPolicyContext binds a settings snapshot to the context for
+// reuse by evaluateOpenAIFastPolicy in goroutines derived from that ctx.
+func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
+    if ctx == nil || settings == nil {
+        return ctx
+    }
+    return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
+}
+
+func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
+    if ctx == nil {
+        return nil
+    }
+    if v, ok := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok {
+        return v
+    }
+    return nil
+}
+
+// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
+// body. When action=filter it removes the service_tier field; when
+// action=block it returns (body, *OpenAIFastBlockedError). On pass it
+// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
+// rewriting the body so the upstream receives a slug it recognizes.
+//
+// Rationale for normalize-on-pass: the chat-completions / messages entry
+// points already normalize service_tier to an upstream-recognized value via
+// normalizeResponsesBodyServiceTier before calling this function; the
+// passthrough (OpenAI auto-passthrough) / native /responses entry points have
+// no such pre-step, so without normalizing here on the pass path, "fast"
+// would be forwarded verbatim to the OpenAI upstream and rejected with a 400.
+// Centralizing the normalization here keeps all entry points consistent.
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
+    if len(body) == 0 {
+        return body, nil
+    }
+    rawTier := gjson.GetBytes(body, "service_tier").String()
+    if rawTier == "" {
+        return body, nil
+    }
+    normTier := normalizedOpenAIServiceTierValue(rawTier)
+    if normTier == "" {
+        return body, nil
+    }
+    action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
+    switch action {
+    case BetaPolicyActionBlock:
+        msg := errMsg
+        if msg == "" {
+            msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
+        }
+        return body, &OpenAIFastBlockedError{Message: msg}
+    case BetaPolicyActionFilter:
+        trimmed, err := sjson.DeleteBytes(body, "service_tier")
+        if err != nil {
+            return body, fmt.Errorf("strip service_tier from body: %w", err)
+        }
+        return trimmed, nil
+    default:
+        // pass: rewrite an alias such as "fast" back to the canonical value ("priority").
+        if normTier == rawTier {
+            return body, nil
+        }
+        updated, err := sjson.SetBytes(body, "service_tier", normTier)
+        if err != nil {
+            return body, fmt.Errorf("normalize service_tier on pass: %w", err)
+        }
+        return updated, nil
+    }
+}
+
+// writeOpenAIFastPolicyBlockedResponse writes a 403 JSON response for a
+// request blocked by the OpenAI fast policy.
+func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
+    if c == nil || err == nil {
+        return
+    }
+    c.JSON(http.StatusForbidden, gin.H{
+        "error": gin.H{
+            "type":    "permission_error",
+            "message": err.Message,
+        },
+    })
+}
+
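The 403 body written above is the HTTP-side signal for a blocked request. Its shape, rendered with the standard library for illustration (a sketch of the same payload; the service itself uses gin's c.JSON):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// blockedResponseBody mirrors the gin.H payload produced by
// writeOpenAIFastPolicyBlockedResponse for an action=block rejection.
func blockedResponseBody(message string) ([]byte, error) {
	return json.Marshal(map[string]any{
		"error": map[string]any{
			"type":    "permission_error",
			"message": message,
		},
	})
}

func main() {
	body, _ := blockedResponseBody("openai service_tier=priority is not allowed for model gpt-5.5")
	fmt.Println(string(body)) // served with HTTP 403
}
```

Keeping `type` stable ("permission_error") gives programmatic clients something to branch on without parsing the human-readable message.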
+// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
+// against a single client→upstream WebSocket frame whose top-level
+// "type"=="response.create". It mirrors the HTTP-side
+// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
+// WS payload:
+//
+//   - pass: returns frame unchanged (newBytes == frame, blocked == nil)
+//   - filter: returns a copy with top-level service_tier removed
+//   - block: returns (frame, *OpenAIFastBlockedError)
+//
+// Only frames whose "type" field strictly equals "response.create" are
+// inspected/mutated. Any other frame type — including the empty string —
+// passes through untouched. The OpenAI Realtime client-event spec requires
+// "type" to be set, so an empty type is treated as a malformed frame we do
+// not police; the upstream is the source of truth for rejecting it.
+//
+// service_tier lives at the top level of response.create — same as the
+// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
+// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
+// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
+// to inspect / strip the top-level field; there is no nested form in the
+// schema today.
+//
+// The caller is responsible for choosing the upstream model passed in —
+// this helper does not re-derive it.
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
+    ctx context.Context,
+    account *Account,
+    model string,
+    frame []byte,
+) ([]byte, *OpenAIFastBlockedError, error) {
+    if len(frame) == 0 {
+        return frame, nil, nil
+    }
+    if !gjson.ValidBytes(frame) {
+        return frame, nil, nil
+    }
+    frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
+    // Strict match: only response.create is policy-checked. Empty / other
+    // types pass through untouched so we never accidentally strip fields
+    // from response.cancel, conversation.item.create, or any future
+    // client-event the spec adds. The Realtime spec requires "type" on
+    // every client event, so an empty type is malformed input — let the
+    // upstream reject it rather than guessing at our layer.
+    if frameType != "response.create" {
+        return frame, nil, nil
+    }
+    rawTier := gjson.GetBytes(frame, "service_tier").String()
+    if rawTier == "" {
+        return frame, nil, nil
+    }
+    normTier := normalizedOpenAIServiceTierValue(rawTier)
+    if normTier == "" {
+        return frame, nil, nil
+    }
+    action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
+    switch action {
+    case BetaPolicyActionBlock:
+        msg := errMsg
+        if msg == "" {
+            msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
+        }
+        return frame, &OpenAIFastBlockedError{Message: msg}, nil
+    case BetaPolicyActionFilter:
+        trimmed, err := sjson.DeleteBytes(frame, "service_tier")
+        if err != nil {
+            return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
+        }
+        return trimmed, nil, nil
+    default:
+        return frame, nil, nil
+    }
+}
+
+// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a
+// server-emitted error event. Matches the loose "evt_<rand>" convention used
+// by upstream Realtime servers; the exact value is not load-bearing and is
+// only required for client-side log correlation. We reuse the existing
+// google/uuid dependency rather than pulling a new one.
+func newOpenAIFastPolicyWSEventID() string {
+    id, err := uuid.NewRandom()
+    if err != nil {
+        // Extremely unlikely; fall back to a fixed prefix so the field is
+        // still non-empty and the schema stays self-consistent.
+        return "evt_openai_fast_policy"
+    }
+    // Strip dashes so it visually matches "evt_<hex>" rather than UUID v4
+    // canonical form, mirroring what real Realtime traces look like.
+    return "evt_" + strings.ReplaceAll(id.String(), "-", "")
+}
+
+// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
+// style "error" event payload for a request blocked by the OpenAI fast
+// policy. The shape mirrors Realtime error events as observed in upstream
+// traces and per the spec's server "error" event:
+//
+//    {
+//      "event_id": "evt_<random>",
+//      "type": "error",
+//      "error": {
+//        "type": "invalid_request_error",
+//        "code": "policy_violation",
+//        "message": "..."
+//      }
+//    }
+//
+// event_id lets clients correlate the rejection in their logs; "code" gives
+// programmatic clients a stable identifier (HTTP-side equivalent is the
+// 403 permission_error JSON body).
+func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
+    if err == nil {
+        return nil
+    }
+    eventID := newOpenAIFastPolicyWSEventID()
+    payload, mErr := json.Marshal(map[string]any{
+        "event_id": eventID,
+        "type":     "error",
+        "error": map[string]any{
+            "type":    "invalid_request_error",
+            "code":    "policy_violation",
+            "message": err.Message,
+        },
+    })
+    if mErr != nil {
+        // Fallback to a minimal hand-rolled payload; Marshal of the literal
+        // shape above should never fail in practice.
+        return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
+    }
+    return payload
+}
+
 func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
     if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
return body, false, nil
|
return body, false, nil
|
||||||
|
|||||||
@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
|
|||||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
|
||||||
|
|
||||||
|
t.Run("stateless image body stays unstuck", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
|
||||||
|
require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prompt_cache_key is explicit", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
|
||||||
|
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
|
||||||
|
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("header overrides body", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
c.Request.Header.Set("session_id", "header-session")
|
||||||
|
|
||||||
|
got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
|
||||||
|
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
|
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -1732,6 +1767,24 @@ func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":"compact me"}],"instructions":"compact-test","tools":[{"type":"function","name":"shell"}],"parallel_tool_calls":true,"reasoning":{"effort":"high"},"text":{"verbosity":"low"},"previous_response_id":"resp_123","store":true,"stream":true,"prompt_cache_key":"cache_123"}`)
|
||||||
|
|
||||||
|
normalized, changed, err := normalizeOpenAICompactRequestBody(body)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, changed)
|
||||||
|
require.Equal(t, "gpt-5.5", gjson.GetBytes(normalized, "model").String())
|
||||||
|
require.True(t, gjson.GetBytes(normalized, "tools").Exists())
|
||||||
|
require.True(t, gjson.GetBytes(normalized, "parallel_tool_calls").Bool())
|
||||||
|
require.Equal(t, "high", gjson.GetBytes(normalized, "reasoning.effort").String())
|
||||||
|
require.Equal(t, "low", gjson.GetBytes(normalized, "text.verbosity").String())
|
||||||
|
require.Equal(t, "resp_123", gjson.GetBytes(normalized, "previous_response_id").String())
|
||||||
|
require.False(t, gjson.GetBytes(normalized, "store").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(normalized, "stream").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(normalized, "prompt_cache_key").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
|
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -258,6 +259,25 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T)
|
|||||||
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
|
||||||
|
require.Equal(t,
|
||||||
|
"https://image-upstream.example/v1/images/generations",
|
||||||
|
buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint),
|
||||||
|
)
|
||||||
|
require.Equal(t,
|
||||||
|
"https://image-upstream.example/v1/images/edits",
|
||||||
|
buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint),
|
||||||
|
)
|
||||||
|
require.Equal(t,
|
||||||
|
"https://image-upstream.example/v1/images/generations",
|
||||||
|
buildOpenAIImagesURL("https://image-upstream.example", openAIImagesGenerationsEndpoint),
|
||||||
|
)
|
||||||
|
require.Equal(t,
|
||||||
|
"https://image-upstream.example/v1/images/generations",
|
||||||
|
buildOpenAIImagesURL("https://image-upstream.example/v1/images/generations", openAIImagesGenerationsEndpoint),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
type openAIImageTestSSEEvent struct {
|
type openAIImageTestSSEEvent struct {
|
||||||
Name string
|
Name string
|
||||||
Data string
|
Data string
|
||||||
@@ -371,6 +391,124 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
|||||||
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
httpUpstream: &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-Request-Id": []string{"req_img_apikey"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"created":1710000007,"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 6,
|
||||||
|
Name: "openai-apikey",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
"base_url": "https://image-upstream.example/v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
require.Equal(t, "gpt-image-2", result.Model)
|
||||||
|
require.Equal(t, "gpt-image-2", result.UpstreamModel)
|
||||||
|
|
||||||
|
upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
|
||||||
|
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
|
||||||
|
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type"))
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
require.NoError(t, writer.WriteField("model", "gpt-image-2"))
|
||||||
|
require.NoError(t, writer.WriteField("prompt", "replace background"))
|
||||||
|
+	imagePart, err := writer.CreateFormFile("image", "source.png")
+	require.NoError(t, err)
+	_, err = imagePart.Write([]byte("png-image-content"))
+	require.NoError(t, err)
+	require.NoError(t, writer.Close())
+
+	req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
+	req.Header.Set("Content-Type", writer.FormDataContentType())
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = req
+
+	svc := &OpenAIGatewayService{
+		cfg: &config.Config{},
+		httpUpstream: &httpUpstreamRecorder{
+			resp: &http.Response{
+				StatusCode: http.StatusOK,
+				Header: http.Header{
+					"Content-Type": []string{"application/json"},
+					"X-Request-Id": []string{"req_img_edit_apikey"},
+				},
+				Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"data":[{"b64_json":"ZWRpdGVk","revised_prompt":"replace background"}]}`)),
+			},
+		},
+	}
+	parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
+	require.NoError(t, err)
+
+	account := &Account{
+		ID:       7,
+		Name:     "openai-apikey",
+		Platform: PlatformOpenAI,
+		Type:     AccountTypeAPIKey,
+		Credentials: map[string]any{
+			"api_key":  "test-api-key",
+			"base_url": "https://image-upstream.example/v1/",
+		},
+	}
+
+	result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "")
+	require.NoError(t, err)
+	require.NotNil(t, result)
+	require.Equal(t, 1, result.ImageCount)
+
+	upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
+	require.True(t, ok)
+	require.NotNil(t, upstream.lastReq)
+	require.Equal(t, "https://image-upstream.example/v1/images/edits", upstream.lastReq.URL.String())
+	require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
+	require.Contains(t, upstream.lastReq.Header.Get("Content-Type"), "multipart/form-data")
+	require.Contains(t, string(upstream.lastBody), `name="model"`)
+	require.Contains(t, string(upstream.lastBody), "gpt-image-2")
+	require.Equal(t, http.StatusOK, rec.Code)
+	require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
+}
+
 func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
@@ -0,0 +1,33 @@
+package service
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
+)
+
+func TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser(t *testing.T) {
+	body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"prompt_cache_retention":"24h","safety_identifier":"sid","stream_options":{"include_usage":true}}`)
+
+	normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, false)
+	require.NoError(t, err)
+	require.True(t, changed)
+	for _, field := range openAIChatGPTInternalUnsupportedFields {
+		require.False(t, gjson.GetBytes(normalized, field).Exists(), "%s should be stripped", field)
+	}
+	require.True(t, gjson.GetBytes(normalized, "stream").Bool())
+	require.False(t, gjson.GetBytes(normalized, "store").Bool())
+}
+
+func TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser(t *testing.T) {
+	body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"stream":true,"store":true}`)
+
+	normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, true)
+	require.NoError(t, err)
+	require.True(t, changed)
+	require.False(t, gjson.GetBytes(normalized, "user").Exists())
+	require.False(t, gjson.GetBytes(normalized, "metadata").Exists())
+	require.False(t, gjson.GetBytes(normalized, "stream").Exists())
+	require.False(t, gjson.GetBytes(normalized, "store").Exists())
+}
@@ -1366,16 +1366,25 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string
 func shouldInferIngressFunctionCallOutputPreviousResponseID(
 	storeDisabled bool,
 	turn int,
-	hasFunctionCallOutput bool,
+	signals ToolContinuationSignals,
 	currentPreviousResponseID string,
 	expectedPreviousResponseID string,
 ) bool {
-	if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
+	if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
 		return false
 	}
 	if strings.TrimSpace(currentPreviousResponseID) != "" {
 		return false
 	}
+	if signals.HasFunctionCallOutputMissingCallID {
+		return false
+	}
+	// If the client already sent tool-call context or item_reference anchors,
+	// treat this as a full replay / self-contained continuation payload rather
+	// than downgrading it into an inferred delta continuation.
+	if signals.HasToolCallContext || signals.HasItemReferenceForAllCallIDs {
+		return false
+	}
 	return strings.TrimSpace(expectedPreviousResponseID) != ""
 }
 
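The hunk above changes the gating predicate from a single boolean to a richer `ToolContinuationSignals` struct. A self-contained sketch of the resulting decision order, with field and parameter names taken from the diff (the standalone `shouldInfer` wrapper and struct definition here are reconstructions for illustration):

```go
package main

import "strings"

// ToolContinuationSignals mirrors the fields the gateway derives from a
// Responses request body, as used in the diff above.
type ToolContinuationSignals struct {
	HasFunctionCallOutput              bool
	HasFunctionCallOutputMissingCallID bool
	HasToolCallContext                 bool
	HasItemReferenceForAllCallIDs      bool
}

// shouldInfer reproduces the gating order of
// shouldInferIngressFunctionCallOutputPreviousResponseID: only backfill a
// previous_response_id when store=false, this is a follow-up turn carrying a
// function_call_output, the client sent no anchor of its own, and the
// previous turn actually produced a response ID.
func shouldInfer(storeDisabled bool, turn int, s ToolContinuationSignals, currentPrev, expectedPrev string) bool {
	if !storeDisabled || turn <= 1 || !s.HasFunctionCallOutput {
		return false
	}
	if strings.TrimSpace(currentPrev) != "" {
		return false // client already anchored the continuation itself
	}
	if s.HasFunctionCallOutputMissingCallID {
		return false // malformed output item; inferring an anchor would not help
	}
	if s.HasToolCallContext || s.HasItemReferenceForAllCallIDs {
		return false // self-contained replay payload; do not downgrade it
	}
	return strings.TrimSpace(expectedPrev) != ""
}

func main() {
	println(shouldInfer(true, 2, ToolContinuationSignals{HasFunctionCallOutput: true}, "", "resp_1"))                           // true
	println(shouldInfer(true, 2, ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true}, "", "resp_1")) // false
}
```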
@@ -2366,6 +2375,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
 		return errors.New("token is empty")
 	}
+
+	// Prefetch the OpenAI Fast Policy settings once and bind them to ctx so every
+	// evaluateOpenAIFastPolicy call in this WS session reuses the same snapshot
+	// instead of hitting the DB / settingRepo per frame. Trade-off: see the withOpenAIFastPolicyContext comment.
+	if s.settingService != nil {
+		if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
+			ctx = withOpenAIFastPolicyContext(ctx, settings)
+		}
+	}
+
 	wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
 	modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
 	ingressMode := OpenAIWSIngressModeCtxPool
@@ -2524,6 +2542,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
 		normalized = next
 	}
+
+	// Apply OpenAI Fast Policy on the response.create frame using the same
+	// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
+	// single integration point for all WS ingress turns (first + follow-up
+	// frames flow through here).
+	//
+	// Model fallback: parseClientPayload above rejects any frame whose
+	// "model" field is missing (line ~2493-2500), so by the time we
+	// reach this point upstreamModel is always derived from a non-empty
+	// per-frame model. The capturedSessionModel fallback used in the
+	// passthrough adapter is therefore not needed in this path.
+	policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
+	if policyErr != nil {
+		return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
+	}
+	if blocked != nil {
+		// Send a Realtime-style error event to the client first, then
+		// signal the handler to close the connection with PolicyViolation.
+		// We intentionally do NOT forward this frame upstream.
+		//
+		// coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
+		// the underlying bufio writer before returning (write.go:42 →
+		// 307-311), and the subsequent close handshake re-acquires the
+		// same writeFrameMu, so the error event is guaranteed to reach
+		// the kernel send buffer before any close frame is queued.
+		eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+		if eventBytes != nil {
+			writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+			_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+			cancel()
+		}
+		return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
+			coderws.StatusPolicyViolation,
+			blocked.Message,
+			blocked,
+		)
+	}
+	normalized = policyApplied
+
 	return openAIWSClientPayload{
 		payloadRaw: normalized,
 		rawForHash: trimmed,
@@ -3132,13 +3188,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
 			skipBeforeTurn = false
 			currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
 			expectedPrev := strings.TrimSpace(lastTurnResponseID)
-			hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
+			toolSignals := ToolContinuationSignals{
+				HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
+			}
+			if toolSignals.HasFunctionCallOutput {
+				var currentReqBody map[string]any
+				if err := json.Unmarshal(currentPayload, &currentReqBody); err == nil {
+					toolSignals = AnalyzeToolContinuationSignals(currentReqBody)
+				}
+			}
+			hasFunctionCallOutput := toolSignals.HasFunctionCallOutput
 			// The store=false + function_call_output case must carry a continuation anchor.
 			// If the client omitted previous_response_id, backfill the previous turn's response ID so the upstream can associate the call_id.
 			if shouldInferIngressFunctionCallOutputPreviousResponseID(
 				storeDisabled,
 				turn,
-				hasFunctionCallOutput,
+				toolSignals,
 				currentPreviousResponseID,
 				expectedPrev,
 			) {
@@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
 	require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "should not auto-fill previous_response_id when the previous turn lacks response.id")
 }
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	cfg := &config.Config{}
+	cfg.Security.URLAllowlist.Enabled = false
+	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+	captureConn := &openAIWSCaptureConn{
+		events: [][]byte{
+			[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+			[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+		},
+	}
+	captureDialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{captureConn},
+	}
+	pool := newOpenAIWSConnPool(cfg)
+	pool.setClientDialerForTest(captureDialer)
+
+	svc := &OpenAIGatewayService{
+		cfg:              cfg,
+		httpUpstream:     &httpUpstreamRecorder{},
+		cache:            &stubGatewayCache{},
+		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+		toolCorrector:    NewCodexToolCorrector(),
+		openaiWSPool:     pool,
+	}
+
+	account := &Account{
+		ID:          114,
+		Name:        "openai-ingress-tool-context",
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 1,
+		Credentials: map[string]any{
+			"api_key": "sk-test",
+		},
+		Extra: map[string]any{
+			"responses_websockets_v2_enabled": true,
+		},
+	}
+
+	serverErrCh := make(chan error, 1)
+	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+			CompressionMode: coderws.CompressionContextTakeover,
+		})
+		if err != nil {
+			serverErrCh <- err
+			return
+		}
+		defer func() {
+			_ = conn.CloseNow()
+		}()
+
+		rec := httptest.NewRecorder()
+		ginCtx, _ := gin.CreateTestContext(rec)
+		req := r.Clone(r.Context())
+		req.Header = req.Header.Clone()
+		req.Header.Set("User-Agent", "unit-test-agent/1.0")
+		ginCtx.Request = req
+
+		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+		msgType, firstMessage, readErr := conn.Read(readCtx)
+		cancel()
+		if readErr != nil {
+			serverErrCh <- readErr
+			return
+		}
+		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+			serverErrCh <- errors.New("unsupported websocket client message type")
+			return
+		}
+
+		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+	}))
+	defer wsServer.Close()
+
+	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+	cancelDial()
+	require.NoError(t, err)
+	defer func() {
+		_ = clientConn.CloseNow()
+	}()
+
+	writeMessage := func(payload string) {
+		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+		defer cancel()
+		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+	}
+	readMessage := func() []byte {
+		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+		defer cancel()
+		msgType, message, readErr := clientConn.Read(readCtx)
+		require.NoError(t, readErr)
+		require.Equal(t, coderws.MessageText, msgType)
+		return message
+	}
+
+	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+	firstTurn := readMessage()
+	require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
+	secondTurn := readMessage()
+	require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+	select {
+	case serverErr := <-serverErrCh:
+		require.NoError(t, serverErr)
+	case <-time.After(5 * time.Second):
+		t.Fatal("timed out waiting for the ingress websocket to finish")
+	}
+
+	require.Equal(t, 1, captureDialer.DialCount())
+	require.Len(t, captureConn.writes, 2)
+	require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "should not auto-fill previous_response_id when the request already contains function_call context")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenItemReferencesPresent(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	cfg := &config.Config{}
+	cfg.Security.URLAllowlist.Enabled = false
+	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+	captureConn := &openAIWSCaptureConn{
+		events: [][]byte{
+			[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+			[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+		},
+	}
+	captureDialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{captureConn},
+	}
+	pool := newOpenAIWSConnPool(cfg)
+	pool.setClientDialerForTest(captureDialer)
+
+	svc := &OpenAIGatewayService{
+		cfg:              cfg,
+		httpUpstream:     &httpUpstreamRecorder{},
+		cache:            &stubGatewayCache{},
+		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+		toolCorrector:    NewCodexToolCorrector(),
+		openaiWSPool:     pool,
+	}
+
+	account := &Account{
+		ID:          115,
+		Name:        "openai-ingress-item-reference",
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 1,
+		Credentials: map[string]any{
+			"api_key": "sk-test",
+		},
+		Extra: map[string]any{
+			"responses_websockets_v2_enabled": true,
+		},
+	}
+
+	serverErrCh := make(chan error, 1)
+	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+			CompressionMode: coderws.CompressionContextTakeover,
+		})
+		if err != nil {
+			serverErrCh <- err
+			return
+		}
+		defer func() {
+			_ = conn.CloseNow()
+		}()
+
+		rec := httptest.NewRecorder()
+		ginCtx, _ := gin.CreateTestContext(rec)
+		req := r.Clone(r.Context())
+		req.Header = req.Header.Clone()
+		req.Header.Set("User-Agent", "unit-test-agent/1.0")
+		ginCtx.Request = req
+
+		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+		msgType, firstMessage, readErr := conn.Read(readCtx)
+		cancel()
+		if readErr != nil {
+			serverErrCh <- readErr
+			return
+		}
+		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+			serverErrCh <- errors.New("unsupported websocket client message type")
+			return
+		}
+
+		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+	}))
+	defer wsServer.Close()
+
+	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+	cancelDial()
+	require.NoError(t, err)
+	defer func() {
+		_ = clientConn.CloseNow()
+	}()
+
+	writeMessage := func(payload string) {
+		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+		defer cancel()
+		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+	}
+	readMessage := func() []byte {
+		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+		defer cancel()
+		msgType, message, readErr := clientConn.Read(readCtx)
+		require.NoError(t, readErr)
+		require.Equal(t, coderws.MessageText, msgType)
+		return message
+	}
+
+	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+	firstTurn := readMessage()
+	require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
+	secondTurn := readMessage()
+	require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+	select {
+	case serverErr := <-serverErrCh:
+		require.NoError(t, serverErr)
+	case <-time.After(5 * time.Second):
+		t.Fatal("timed out waiting for the ingress websocket to finish")
+	}
+
+	require.Equal(t, 1, captureDialer.DialCount())
+	require.Len(t, captureConn.writes, 2)
+	require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "should not auto-fill previous_response_id when the request already contains item_reference anchors")
+}
+
 func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
@@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
 		name                    string
 		storeDisabled           bool
 		turn                    int
-		hasFunctionCallOutput   bool
+		signals                 ToolContinuationSignals
 		currentPreviousResponse string
 		expectedPrevious        string
 		want                    bool
 	}{
 		{
 			name:          "infer_when_all_conditions_match",
 			storeDisabled: true,
 			turn:          2,
-			hasFunctionCallOutput: true,
+			signals:          ToolContinuationSignals{HasFunctionCallOutput: true},
 			expectedPrevious: "resp_1",
 			want:             true,
 		},
 		{
 			name:          "skip_when_store_enabled",
 			storeDisabled: false,
 			turn:          2,
-			hasFunctionCallOutput: true,
+			signals:          ToolContinuationSignals{HasFunctionCallOutput: true},
 			expectedPrevious: "resp_1",
 			want:             false,
 		},
 		{
 			name:          "skip_on_first_turn",
 			storeDisabled: true,
 			turn:          1,
-			hasFunctionCallOutput: true,
+			signals:          ToolContinuationSignals{HasFunctionCallOutput: true},
 			expectedPrevious: "resp_1",
 			want:             false,
 		},
 		{
 			name:          "skip_without_function_call_output",
 			storeDisabled: true,
 			turn:          2,
-			hasFunctionCallOutput: false,
+			signals:          ToolContinuationSignals{},
 			expectedPrevious: "resp_1",
 			want:             false,
 		},
 		{
 			name:          "skip_when_request_already_has_previous_response_id",
 			storeDisabled: true,
 			turn:          2,
-			hasFunctionCallOutput: true,
+			signals:                 ToolContinuationSignals{HasFunctionCallOutput: true},
 			currentPreviousResponse: "resp_client",
 			expectedPrevious:        "resp_1",
 			want:                    false,
 		},
 		{
 			name:          "skip_when_last_turn_response_id_missing",
 			storeDisabled: true,
 			turn:          2,
-			hasFunctionCallOutput: true,
+			signals:          ToolContinuationSignals{HasFunctionCallOutput: true},
 			expectedPrevious: "",
 			want:             false,
 		},
 		{
 			name:          "trim_whitespace_before_judgement",
 			storeDisabled: true,
 			turn:          2,
-			hasFunctionCallOutput: true,
+			signals:          ToolContinuationSignals{HasFunctionCallOutput: true},
 			expectedPrevious: " resp_2 ",
 			want:             true,
+		},
+		{
+			name:             "skip_when_tool_call_context_already_present",
+			storeDisabled:    true,
+			turn:             2,
+			signals:          ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true},
+			expectedPrevious: "resp_2",
+			want:             false,
+		},
+		{
+			name:             "skip_when_item_reference_already_covers_all_call_ids",
+			storeDisabled:    true,
+			turn:             2,
+			signals:          ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
+			expectedPrevious: "resp_2",
+			want:             false,
+		},
+		{
+			name:             "skip_when_function_call_output_missing_call_id",
+			storeDisabled:    true,
+			turn:             2,
+			signals:          ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true},
+			expectedPrevious: "resp_2",
+			want:             false,
 		},
 	}
 
@@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
 		got := shouldInferIngressFunctionCallOutputPreviousResponseID(
 			tt.storeDisabled,
 			tt.turn,
-			tt.hasFunctionCallOutput,
+			tt.signals,
 			tt.currentPreviousResponse,
 			tt.expectedPrevious,
 		)
@@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
 		nil,
 		nil,
 		nil,
+		nil,
 	)
 
 	decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
|||||||
@@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
|
|||||||
conn *coderws.Conn
|
conn *coderws.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
|
||||||
|
// every client→upstream frame through the OpenAI Fast Policy. It is the
|
||||||
|
// passthrough-relay equivalent of the parseClientPayload integration in the
|
||||||
|
// ingress session path. filter returns:
|
||||||
|
// - newPayload, nil, nil: forward the (possibly mutated) payload
|
||||||
|
// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
|
||||||
|
// event via onBlock and surfaces a transport-level error so the relay
|
||||||
|
// stops reading from the client.
|
||||||
|
// - _, _, err: a transport error other than block.
|
||||||
|
type openAIWSPolicyEnforcingFrameConn struct {
|
||||||
|
inner openaiwsv2.FrameConn
|
||||||
|
filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
|
||||||
|
onBlock func(blocked *OpenAIFastBlockedError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
|
||||||
|
|
||||||
|
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||||
|
if c == nil || c.inner == nil {
|
||||||
|
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||||
|
}
|
||||||
|
msgType, payload, err := c.inner.ReadFrame(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return msgType, payload, err
|
||||||
|
}
|
||||||
|
if c.filter == nil {
|
||||||
|
return msgType, payload, nil
|
||||||
|
}
|
||||||
|
updated, blocked, filterErr := c.filter(msgType, payload)
|
||||||
|
+	if filterErr != nil {
+		return msgType, payload, filterErr
+	}
+	if blocked != nil {
+		if c.onBlock != nil {
+			c.onBlock(blocked)
+		}
+		return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
+	}
+	return msgType, updated, nil
+}
+
+func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+	if c == nil || c.inner == nil {
+		return errOpenAIWSConnClosed
+	}
+	return c.inner.WriteFrame(ctx, msgType, payload)
+}
+
+func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
+	if c == nil || c.inner == nil {
+		return nil
+	}
+	return c.inner.Close()
+}
+
+// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
+// model name that should be passed to evaluateOpenAIFastPolicy for a single
+// passthrough WS frame. Mirrors the HTTP-side normalization
+// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
+// matches model whitelists identically.
+func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
+	if account == nil || len(payload) == 0 {
+		return ""
+	}
+	original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
+	if original == "" {
+		return ""
+	}
+	return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
+}
+
+// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
+// derived from a session.update frame's session.model field. Returns "" when
+// the frame is not a session.update event or carries no session.model. Used
+// by the per-frame policy filter (client→upstream direction) to keep
+// capturedSessionModel in sync with the session-level model the client may
+// rotate mid-session.
+//
+// Realtime / Responses WS lets the client change the session model after
+// the WS handshake via:
+//
+//	{"type":"session.update","session":{"model":"gpt-5.5", ...}}
+//
+// If we only capture the model from the very first frame, a client can ship
+// gpt-4o on the first response.create (whitelisted as pass), then
+// session.update to gpt-5.5, then send response.create without "model" so
+// the per-frame resolver returns "" and the stale capturedSessionModel falls
+// back to gpt-4o, defeating the gpt-5.5 fast-policy filter.
+func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
+	if account == nil || len(payload) == 0 {
+		return ""
+	}
+	frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
+	if frameType != "session.update" {
+		return ""
+	}
+	original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
+	if original == "" {
+		return ""
+	}
+	return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
+}
+
 const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
 
 var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
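The session.update bypass described in the comment above can be reproduced with a much-simplified resolver. The types and helpers below are illustrative stand-ins for the real Account/gjson plumbing, not the actual implementation:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// frameModel extracts a frame's own top-level "model"; "" if absent.
func frameModel(payload []byte) string {
	var f struct {
		Model string `json:"model"`
	}
	_ = json.Unmarshal(payload, &f)
	return f.Model
}

// sessionModel extracts "session.model" from a session.update frame; "" otherwise.
func sessionModel(payload []byte) string {
	var f struct {
		Type    string `json:"type"`
		Session struct {
			Model string `json:"model"`
		} `json:"session"`
	}
	_ = json.Unmarshal(payload, &f)
	if f.Type != "session.update" {
		return ""
	}
	return f.Session.Model
}

// resolve mirrors the per-frame fallback: first refresh the captured
// session-level model from session.update frames, then prefer the frame's
// own model, else fall back to the captured one.
func resolve(captured *string, payload []byte) string {
	if m := sessionModel(payload); m != "" {
		*captured = m
	}
	if m := frameModel(payload); m != "" {
		return m
	}
	return *captured
}

func main() {
	captured := "gpt-4o" // model seen on the first response.create
	frames := [][]byte{
		[]byte(`{"type":"session.update","session":{"model":"gpt-5.5"}}`),
		[]byte(`{"type":"response.create"}`), // legally omits "model"
	}
	for _, f := range frames {
		fmt.Println(resolve(&captured, f)) // resolves to gpt-5.5 for both frames
	}
}
```

Without the session.update refresh, the second frame would fall back to the stale gpt-4o, which is exactly the bypass the comment warns about.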
@@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
 		return errors.New("token is empty")
 	}
 	requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
-	requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
 	requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
 	logOpenAIWSV2Passthrough(
 		"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
@@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
 		len(firstClientMessage),
 	)
+
+	// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
+	// frames are filtered via a wrapping FrameConn below so every client→
+	// upstream frame goes through the same policy evaluator/normalize/scope as
+	// HTTP entrypoints.
+	//
+	// We capture the session-level model from the first frame here so the
+	// per-frame filter (below) can fall back to it when a follow-up frame
+	// omits "model": Realtime clients are allowed to send response.create
+	// without re-stating the model, in which case the upstream uses the model
+	// negotiated at session.update time. Without this fallback, an empty
+	// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
+	// silently passed through, defeating the policy on every frame after
+	// the first.
+	capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
+	updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
+	if policyErr != nil {
+		return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
+	}
+	if blocked != nil {
+		// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
+		// writeFrameMu, writes the entire frame, and Flushes the underlying
+		// bufio writer before returning (write.go:42 → write.go:307-311).
+		// The subsequent close handshake re-acquires the same writeFrameMu
+		// to send the close frame, so the error event is guaranteed to
+		// reach the kernel send buffer before any close frame is queued.
+		// No explicit flush hop is required here.
+		eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+		if eventBytes != nil {
+			writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+			_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+			cancelWrite()
+		}
+		return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
+	}
+	firstClientMessage = updatedFirst
+
+	// Extract service_tier for billing only after the policy filter runs:
+	// when the filter fires, service_tier has already been stripped from
+	// firstClientMessage, and billing should reflect the tier the upstream
+	// actually processed (nil = default), not the "priority" the user
+	// originally requested. This matches the HTTP entrypoint (line ~2728,
+	// extractOpenAIServiceTier(reqBody)) and the WS ingress
+	// (openai_ws_forwarder.go:2991, taken from the payload).
+	//
+	// Multi-turn passthrough: the OpenAI Realtime / Responses WS protocol
+	// lets the client send a different service_tier on different
+	// response.create frames of the same connection (see
+	// codex-rs/core/src/client.rs build_responses_request, which re-fills
+	// the value every time). We therefore use an atomic.Pointer[string] to
+	// synchronize the current turn's service_tier between the filter
+	// (runClientToUpstream goroutine) and OnTurnComplete / final result
+	// (runUpstreamToClient goroutine).
+	// extractOpenAIServiceTierFromBody returns *string, already a pointer
+	// type, so it can be Store/Load-ed directly without extra wrapping.
+	var requestServiceTierPtr atomic.Pointer[string]
+	requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
+
 	wsURL, err := s.buildOpenAIResponsesWSURL(account)
 	if err != nil {
 		return fmt.Errorf("build ws url: %w", err)
@@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
 	}
 
 	completedTurns := atomic.Int32{}
+	policyClientConn := &openAIWSPolicyEnforcingFrameConn{
+		inner: &openAIWSClientFrameConn{conn: clientConn},
+		// Thread-safety note: the filter is only ever invoked from the single
+		// runClientToUpstream goroutine (passthrough_relay.go: ReadFrame
+		// loop), and capturedSessionModel is read and written only inside
+		// that goroutine, so no locking or atomics are needed.
+		filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+			if msgType != coderws.MessageText {
+				return payload, nil, nil
+			}
+			// Refresh capturedSessionModel before evaluating the policy: the
+			// client may change the session-level model via session.update
+			// (allowed by the Realtime / Responses WS protocol). Without the
+			// refresh we get the bypass path "first frame model=gpt-4o (pass)
+			// → session.update to gpt-5.5 → response.create without model
+			// falls back to gpt-4o". Only the session.model field of
+			// session.update events is considered here; a response.create's
+			// model is still decided by its own frame field.
+			if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
+				capturedSessionModel = updated
+			}
+			// Per-frame model first; if the client omits "model" on a
+			// follow-up frame (legal in Realtime), fall back to the
+			// session-level model captured from the first frame so the
+			// model whitelist still resolves. An empty model would miss
+			// any whitelist and silently fall back to pass.
+			model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+			if model == "" {
+				model = capturedSessionModel
+			}
+			out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
+			// Multi-turn passthrough billing: update requestServiceTierPtr
+			// only on successful (non-block / non-err) response.create
+			// frames, using the filtered payload, consistent with the first
+			// frame's policy-then-extract semantics (see the
+			// extractOpenAIServiceTierFromBody comment above).
+			// - Non-response.create frames (response.cancel /
+			//   conversation.item.create / session.update, etc.) carry no
+			//   per-response service_tier and must not overwrite the
+			//   previous turn's value.
+			// - blocked != nil: the frame is never sent upstream, so the
+			//   billing tier keeps the previous turn's value.
+			// - policyErr != nil: error path, keep the previous value.
+			// - A response.create without service_tier makes
+			//   extractOpenAIServiceTierFromBody return nil; this
+			//   deliberately overwrites (Store(nil)), because when no
+			//   service_tier is actually sent the OpenAI upstream treats
+			//   the frame as default, and billing should reflect that.
+			if policyErr == nil && blocked == nil &&
+				strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
+				requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
+			}
+			return out, blocked, policyErr
+		},
+		onBlock: func(blocked *OpenAIFastBlockedError) {
+			// See note above on Conn.Write being synchronous w.r.t. flush;
+			// no explicit flush is required to ensure the error event lands
+			// before the close frame.
+			eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+			if eventBytes == nil {
+				return
+			}
+			writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+			_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+			cancel()
+		},
+	}
 	relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
 		Ctx:                ctx,
-		ClientConn:         &openAIWSClientFrameConn{conn: clientConn},
+		ClientConn:         policyClientConn,
 		UpstreamConn:       upstreamFrameConn,
 		FirstClientMessage: firstClientMessage,
 		Options: openaiwsv2.RelayOptions{
@@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
 				CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
 			},
 			Model:           turn.RequestModel,
-			ServiceTier:     requestServiceTier,
+			ServiceTier:     requestServiceTierPtr.Load(),
 			Stream:          true,
 			OpenAIWSMode:    true,
 			ResponseHeaders: cloneHeader(handshakeHeaders),
@@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
 				CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
 			},
 			Model:           relayResult.RequestModel,
-			ServiceTier:     requestServiceTier,
+			ServiceTier:     requestServiceTierPtr.Load(),
 			Stream:          true,
 			OpenAIWSMode:    true,
 			ResponseHeaders: cloneHeader(handshakeHeaders),
@@ -184,6 +184,25 @@ func (c opsCleanupDeletedCounts) String() string {
 	)
 }
+
+// opsCleanupPlan translates "retention days" into a concrete cleanup action.
+//   - days < 0  → skip this cleanup item (ok=false); old data is kept for compatibility
+//   - days == 0 → TRUNCATE TABLE (O(1) full wipe), truncate=true
+//   - days > 0  → batched DELETE of rows older than now-N days, cutoff = now - N days
+//
+// Why days==0 uses TRUNCATE rather than a "now+24h cutoff + DELETE":
+//   - speed drops from O(N) to O(1), finishing in milliseconds even on million-row tables
+//   - no WAL writes, no follow-up VACUUM pressure
+//   - these ops tables are written only by the cleanup job itself, so TRUNCATE's
+//     ACCESS EXCLUSIVE lock has negligible impact
+func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) {
+	if days < 0 {
+		return time.Time{}, false, false
+	}
+	if days == 0 {
+		return time.Time{}, true, true
+	}
+	return now.AddDate(0, 0, -days), false, true
+}
+
 func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
 	out := opsCleanupDeletedCounts{}
 	if s == nil || s.db == nil || s.cfg == nil {
@@ -194,34 +213,42 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
 
 	now := time.Now().UTC()
 
-	// Error-like tables: error logs / retry attempts / alert events.
-	if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 {
-		cutoff := now.AddDate(0, 0, -days)
-		n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false)
+	// runOne folds "truncate? cutoff? batched delete?" into one place, so the
+	// three cleanup groups (error-log-like tables / minute metrics /
+	// hourly+daily pre-aggregation) only have to name the table and column.
+	runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) {
+		if truncate {
+			return truncateOpsTable(ctx, s.db, table)
+		}
+		return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate)
+	}
+
+	// Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits.
+	if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok {
+		n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false)
 		if err != nil {
 			return out, err
 		}
 		out.errorLogs = n
 
-		n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false)
+		n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false)
 		if err != nil {
 			return out, err
 		}
 		out.retryAttempts = n
 
-		n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false)
+		n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false)
 		if err != nil {
 			return out, err
 		}
 		out.alertEvents = n
 
-		n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false)
+		n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false)
 		if err != nil {
 			return out, err
 		}
 		out.systemLogs = n
 
-		n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false)
+		n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false)
 		if err != nil {
 			return out, err
 		}
@@ -229,9 +256,8 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
 	}
 
 	// Minute-level metrics snapshots.
-	if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 {
-		cutoff := now.AddDate(0, 0, -days)
-		n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false)
+	if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok {
+		n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false)
 		if err != nil {
 			return out, err
 		}
@@ -239,15 +265,14 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
 	}
 
 	// Pre-aggregation tables (hourly/daily).
-	if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 {
-		cutoff := now.AddDate(0, 0, -days)
-		n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false)
+	if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok {
+		n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false)
 		if err != nil {
 			return out, err
 		}
 		out.hourlyPreagg = n
 
-		n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true)
+		n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true)
 		if err != nil {
 			return out, err
 		}
@@ -303,7 +328,7 @@ WHERE id IN (SELECT id FROM batch)
 		res, err := db.ExecContext(ctx, q, cutoff, batchSize)
 		if err != nil {
 			// If ops tables aren't present yet (partial deployments), treat as no-op.
-			if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") {
+			if isMissingRelationError(err) {
 				return total, nil
 			}
 			return total, err
@@ -320,6 +345,46 @@ WHERE id IN (SELECT id FROM batch)
 	return total, nil
 }
+
+// truncateOpsTable wipes the given table with TRUNCATE TABLE, first taking a
+// SELECT COUNT(*) so the pre-wipe row count can feed the heartbeat.
+//
+// Differences from deleteOldRowsByID:
+//   - no WHERE condition can be given; only for the days==0 "wipe everything" semantics
+//   - O(1) release of the table's physical storage pages, millisecond-level,
+//     no WAL writes and no VACUUM pressure
+//   - needs an ACCESS EXCLUSIVE lock, but the ops tables are written only by
+//     the cleanup job itself, so the momentary lock is negligible
+//
+// A missing table (partial deployments) silently returns 0, matching deleteOldRowsByID.
+func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) {
+	if db == nil {
+		return 0, nil
+	}
+	var count int64
+	if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil {
+		if isMissingRelationError(err) {
+			return 0, nil
+		}
+		return 0, fmt.Errorf("count %s: %w", table, err)
+	}
+	if count == 0 {
+		return 0, nil
+	}
+	if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil {
+		if isMissingRelationError(err) {
+			return 0, nil
+		}
+		return 0, fmt.Errorf("truncate %s: %w", table, err)
+	}
+	return count, nil
+}
+
+// isMissingRelationError reports whether a PG error means "relation does not
+// exist", letting the cleanup job silently skip tables in partial deployments.
+func isMissingRelationError(err error) bool {
+	if err == nil {
+		return false
+	}
+	s := strings.ToLower(err.Error())
+	return strings.Contains(s, "does not exist") && strings.Contains(s, "relation")
+}
+
 func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
 	if s == nil {
 		return nil, false
backend/internal/service/ops_cleanup_service_test.go (new file, 64 lines)
@@ -0,0 +1,64 @@
+package service
+
+import (
+	"testing"
+	"time"
+)
+
+func TestOpsCleanupPlan(t *testing.T) {
+	now := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
+
+	cases := []struct {
+		name         string
+		days         int
+		wantOK       bool
+		wantTruncate bool
+		wantCutoff   time.Time
+	}{
+		{name: "negative skips", days: -1, wantOK: false},
+		{name: "zero truncates", days: 0, wantOK: true, wantTruncate: true},
+		{name: "positive yields past cutoff", days: 7, wantOK: true, wantCutoff: now.AddDate(0, 0, -7)},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			cutoff, truncate, ok := opsCleanupPlan(now, tc.days)
+			if ok != tc.wantOK {
+				t.Fatalf("ok = %v, want %v", ok, tc.wantOK)
+			}
+			if !ok {
+				return
+			}
+			if truncate != tc.wantTruncate {
+				t.Fatalf("truncate = %v, want %v", truncate, tc.wantTruncate)
+			}
+			if !tc.wantTruncate && !cutoff.Equal(tc.wantCutoff) {
+				t.Fatalf("cutoff = %v, want %v", cutoff, tc.wantCutoff)
+			}
+		})
+	}
+}
+
+func TestIsMissingRelationError(t *testing.T) {
+	cases := []struct {
+		name string
+		err  error
+		want bool
+	}{
+		{name: "nil is not missing", err: nil, want: false},
+		{name: "match relation does not exist", err: fakeErr(`pq: relation "ops_error_logs" does not exist`), want: true},
+		{name: "match case-insensitive", err: fakeErr(`ERROR: Relation "x" Does Not Exist`), want: true},
+		{name: "non-matching error", err: fakeErr("connection refused"), want: false},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			if got := isMissingRelationError(tc.err); got != tc.want {
+				t.Fatalf("got %v, want %v", got, tc.want)
+			}
+		})
+	}
+}
+
+type fakeErr string
+
+func (e fakeErr) Error() string { return string(e) }
@@ -387,13 +387,15 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
 	if cfg.DataRetention.CleanupSchedule == "" {
 		cfg.DataRetention.CleanupSchedule = "0 2 * * *"
 	}
-	if cfg.DataRetention.ErrorLogRetentionDays <= 0 {
+	// Retention days: 0 means each scheduled cleanup wipes everything, > 0
+	// keeps that many days; only backfill the default for invalid negative
+	// values, so a user-chosen 0 is never overwritten.
+	if cfg.DataRetention.ErrorLogRetentionDays < 0 {
 		cfg.DataRetention.ErrorLogRetentionDays = 30
 	}
-	if cfg.DataRetention.MinuteMetricsRetentionDays <= 0 {
+	if cfg.DataRetention.MinuteMetricsRetentionDays < 0 {
 		cfg.DataRetention.MinuteMetricsRetentionDays = 30
 	}
-	if cfg.DataRetention.HourlyMetricsRetentionDays <= 0 {
+	if cfg.DataRetention.HourlyMetricsRetentionDays < 0 {
 		cfg.DataRetention.HourlyMetricsRetentionDays = 30
 	}
 	// Normalize auto refresh interval (default 30 seconds)
@@ -406,14 +408,15 @@ func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
 	if cfg == nil {
 		return errors.New("invalid config")
 	}
-	if cfg.DataRetention.ErrorLogRetentionDays < 1 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
-		return errors.New("error_log_retention_days must be between 1 and 365")
+	// Retention days: 0 means wipe everything on each cleanup, 1-365 keeps that many days.
+	if cfg.DataRetention.ErrorLogRetentionDays < 0 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
+		return errors.New("error_log_retention_days must be between 0 and 365")
 	}
-	if cfg.DataRetention.MinuteMetricsRetentionDays < 1 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
-		return errors.New("minute_metrics_retention_days must be between 1 and 365")
+	if cfg.DataRetention.MinuteMetricsRetentionDays < 0 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
+		return errors.New("minute_metrics_retention_days must be between 0 and 365")
 	}
-	if cfg.DataRetention.HourlyMetricsRetentionDays < 1 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
-		return errors.New("hourly_metrics_retention_days must be between 1 and 365")
+	if cfg.DataRetention.HourlyMetricsRetentionDays < 0 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
+		return errors.New("hourly_metrics_retention_days must be between 0 and 365")
 	}
 	if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 {
 		return errors.New("auto_refresh_interval_seconds must be between 15 and 300")
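The `<= 0` → `< 0` change above is what lets an explicit 0 ("wipe everything each run") survive normalization instead of being silently replaced by the 30-day default. A standalone sketch of the two checks (function names illustrative, not from the codebase):

```go
package main

import "fmt"

// normalizeRetention backfills the default only for invalid negative values,
// so a user-set 0 is preserved.
func normalizeRetention(days int) int {
	if days < 0 {
		return 30
	}
	return days
}

// validateRetention mirrors the widened 0..365 range from the diff.
func validateRetention(days int) error {
	if days < 0 || days > 365 {
		return fmt.Errorf("retention days must be between 0 and 365, got %d", days)
	}
	return nil
}

func main() {
	for _, d := range []int{-1, 0, 30, 400} {
		n := normalizeRetention(d)
		fmt.Println(d, "->", n, validateRetention(n))
	}
}
```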
@@ -59,6 +59,8 @@ type SchedulerCache interface {
 	UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
 	// TryLockBucket attempts to acquire the per-bucket rebuild lock.
 	TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
+	// UnlockBucket releases the per-bucket rebuild lock.
+	UnlockBucket(ctx context.Context, bucket SchedulerBucket) error
 	// ListBuckets returns the set of registered buckets.
 	ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
 	// GetOutboxWatermark reads the outbox watermark.
@@ -44,6 +44,10 @@ func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket Sched
 	return true, nil
 }
+
+func (c *snapshotHydrationCache) UnlockBucket(ctx context.Context, bucket SchedulerBucket) error {
+	return nil
+}
+
 func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) {
 	return nil, nil
 }
@@ -544,6 +544,9 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch
 	if !ok {
 		return nil
 	}
+	defer func() {
+		_ = s.cache.UnlockBucket(ctx, bucket)
+	}()
 
 	rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
 	defer cancel()
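The addition above pairs TryLockBucket with a deferred UnlockBucket so the rebuild lock is released on every return path. The same shape can be sketched locally with sync.Mutex.TryLock (Go 1.18+); the function and lock names here are stand-ins, not the service's real API:

```go
package main

import (
	"fmt"
	"sync"
)

var rebuildMu sync.Mutex // stand-in for the per-bucket rebuild lock

func rebuildBucket(name string) error {
	if !rebuildMu.TryLock() { // another worker holds the lock: skip, don't block
		return nil
	}
	defer rebuildMu.Unlock() // released on every return path, as in the diff

	fmt.Println("rebuilding", name)
	return nil
}

func main() {
	_ = rebuildBucket("bucket-a") // acquires, rebuilds, releases

	rebuildMu.Lock() // simulate a concurrent holder
	_ = rebuildBucket("bucket-b") // TryLock fails: silently skipped
	rebuildMu.Unlock()
}
```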
@@ -3259,6 +3259,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
 	return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
 }
+
+// GetOpenAIFastPolicySettings fetches the OpenAI fast policy configuration.
+func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
+	value, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
+	if err != nil {
+		if errors.Is(err, ErrSettingNotFound) {
+			return DefaultOpenAIFastPolicySettings(), nil
+		}
+		return nil, fmt.Errorf("get openai fast policy settings: %w", err)
+	}
+	if value == "" {
+		return DefaultOpenAIFastPolicySettings(), nil
+	}
+
+	var settings OpenAIFastPolicySettings
+	if err := json.Unmarshal([]byte(value), &settings); err != nil {
+		// Silently falling back to defaults on corrupt JSON would quietly
+		// disable the policy (the admin-configured block/filter rules get
+		// ignored). Log a Warn so operators can trace anomalous behavior
+		// back to dirty data in the settings table.
+		slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
+			"error", err,
+			"key", SettingKeyOpenAIFastPolicySettings)
+		return DefaultOpenAIFastPolicySettings(), nil
+	}
+
+	return &settings, nil
+}
+
+// SetOpenAIFastPolicySettings stores the OpenAI fast policy configuration.
+func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
+	if settings == nil {
+		return fmt.Errorf("settings cannot be nil")
+	}
+
+	validActions := map[string]bool{
+		BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
+	}
+	validScopes := map[string]bool{
+		BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
+	}
+	validTiers := map[string]bool{
+		OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true,
+	}
+
+	for i, rule := range settings.Rules {
+		tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
+		if tier == "" {
+			tier = OpenAIFastTierAny
+		}
+		if !validTiers[tier] {
+			return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier)
+		}
+		settings.Rules[i].ServiceTier = tier
+		if !validActions[rule.Action] {
+			return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
+		}
+		if !validScopes[rule.Scope] {
+			return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
+		}
+		for j, pattern := range rule.ModelWhitelist {
+			trimmed := strings.TrimSpace(pattern)
+			if trimmed == "" {
+				return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j)
+			}
+			settings.Rules[i].ModelWhitelist[j] = trimmed
+		}
+		if rule.FallbackAction != "" && !validActions[rule.FallbackAction] {
+			return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction)
+		}
+	}
+
+	data, err := json.Marshal(settings)
+	if err != nil {
+		return fmt.Errorf("marshal openai fast policy settings: %w", err)
+	}
+
+	return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data))
+}
+
 // SetStreamTimeoutSettings stores the stream timeout handling configuration.
 func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
 	if settings == nil {
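GetOpenAIFastPolicySettings above prefers a logged fallback to defaults over failing hard when the stored JSON is corrupt. A stripped-down sketch of that load path; the types and function names are hypothetical simplifications, not the service's real API:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"
)

type FastPolicySettings struct {
	Rules []struct {
		ServiceTier string `json:"service_tier"`
		Action      string `json:"action"`
	} `json:"rules"`
}

func defaultSettings() *FastPolicySettings { return &FastPolicySettings{} }

// loadSettings returns defaults both for a missing value and for corrupt
// JSON, logging the latter so dirty rows in the settings store stay visible.
func loadSettings(raw string) *FastPolicySettings {
	if raw == "" {
		return defaultSettings()
	}
	var s FastPolicySettings
	if err := json.Unmarshal([]byte(raw), &s); err != nil {
		log.Printf("bad fast policy settings, using defaults: %v", err)
		return defaultSettings()
	}
	return &s
}

func main() {
	good := `{"rules":[{"service_tier":"priority","action":"filter"}]}`
	fmt.Println(len(loadSettings(good).Rules))    // rules parsed from valid JSON
	fmt.Println(len(loadSettings("{oops").Rules)) // corrupt JSON: defaults, plus a log line
}
```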
@@ -405,3 +405,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
 		},
 	}
 }
+
+// OpenAI Fast Policy constants.
+// OpenAI's "fast mode" is identified by the service_tier field in the request body:
+//   - "priority" (clients may send "fast", which is normalized to "priority"): fast mode
+//   - "flex": low-priority mode
+//   - omitted: normal (the default)
+//
+// This policy reuses the BetaPolicyAction*/BetaPolicyScope* constant semantics;
+// only the match key changes from the anthropic-beta header to the body's
+// service_tier field.
+const (
+	OpenAIFastTierAny      = "all"      // matches any recognized service_tier
+	OpenAIFastTierPriority = "priority" // matches only fast (priority)
+	OpenAIFastTierFlex     = "flex"     // matches only flex
+)
+
+// OpenAIFastPolicyRule is a single OpenAI fast/flex policy rule.
+type OpenAIFastPolicyRule struct {
+	ServiceTier          string   `json:"service_tier"`                     // "priority" | "flex" | "auto" | "default" | "scale" | "all"
+	Action               string   `json:"action"`                           // "pass" | "filter" | "block"
+	Scope                string   `json:"scope"`                            // "all" | "oauth" | "apikey" | "bedrock"
+	ErrorMessage         string   `json:"error_message,omitempty"`          // custom error message (takes effect when action=block)
+	ModelWhitelist       []string `json:"model_whitelist,omitempty"`        // model match patterns (empty = applies to all models)
+	FallbackAction       string   `json:"fallback_action,omitempty"`        // how to handle models that do not match the whitelist
+	FallbackErrorMessage string   `json:"fallback_error_message,omitempty"` // custom error message for non-whitelisted models (takes effect when fallback_action=block)
+}
+
+// OpenAIFastPolicySettings is the OpenAI fast policy configuration.
+type OpenAIFastPolicySettings struct {
+	Rules []OpenAIFastPolicyRule `json:"rules"`
+}
+
+// DefaultOpenAIFastPolicySettings returns the default OpenAI fast policy configuration.
+// By default, priority (fast) requests for all models are filtered, i.e. the
+// service_tier field is stripped so the upstream handles them at normal priority.
+//
+// Why ModelWhitelist is empty (= applies to all models):
+// the codex client's service_tier=fast is a user-level switch, orthogonal to the
+// model field. Even with gpt-4 + fast, priority quota is still consumed. If the
+// default rule only locked down gpt-5.5*, the "gpt-4 + fast passed through to a
+// priority upstream" path would bypass the policy. To align with codex's actual
+// semantics, the default applies to all models; administrators who need to target
+// specific models can configure model_whitelist explicitly in the admin UI.
+func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
+	return &OpenAIFastPolicySettings{
+		Rules: []OpenAIFastPolicyRule{
+			{
+				ServiceTier:    OpenAIFastTierPriority,
+				Action:         BetaPolicyActionFilter,
+				Scope:          BetaPolicyScopeAll,
+				ModelWhitelist: []string{},
+				FallbackAction: BetaPolicyActionPass,
+			},
+		},
+	}
+}
backend/internal/service/vertex_service_account.go (new file, 345 lines)
@@ -0,0 +1,345 @@
+package service
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"log/slog"
+	"net/http"
+	"net/url"
+	"regexp"
+	"strings"
+	"time"
+
+	"github.com/golang-jwt/jwt/v5"
+)
+
+const (
+	vertexDefaultLocation         = "us-central1"
+	vertexDefaultTokenURL         = "https://oauth2.googleapis.com/token"
+	vertexCloudPlatformScope      = "https://www.googleapis.com/auth/cloud-platform"
+	vertexServiceAccountCacheSkew = 5 * time.Minute
+	vertexLockWaitTime            = 200 * time.Millisecond
+	vertexAnthropicVersion        = "vertex-2023-10-16"
+)
+
+var (
+	vertexLocationPattern                = regexp.MustCompile(`^[a-z0-9-]+$`)
+	vertexAnthropicDatedModelIDPattern   = regexp.MustCompile(`^(.+)-([0-9]{8})$`)
+	vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`)
+)
+
+type vertexServiceAccountKey struct {
+	Type         string `json:"type"`
+	ProjectID    string `json:"project_id"`
+	PrivateKeyID string `json:"private_key_id"`
+	PrivateKey   string `json:"private_key"`
+	ClientEmail  string `json:"client_email"`
+	TokenURI     string `json:"token_uri"`
+}
+
+type vertexTokenResponse struct {
+	AccessToken string `json:"access_token"`
+	TokenType   string `json:"token_type"`
+	ExpiresIn   int64  `json:"expires_in"`
+	Error       string `json:"error"`
+	ErrorDesc   string `json:"error_description"`
+}
+
+func (a *Account) IsVertexServiceAccount() bool {
+	return a != nil && a.Type == AccountTypeServiceAccount
+}
+
+func (a *Account) VertexProjectID() string {
+	if a == nil {
+		return ""
+	}
+	if v := strings.TrimSpace(a.GetCredential("project_id")); v != "" {
+		return v
+	}
+	key, err := parseVertexServiceAccountKey(a)
+	if err == nil {
+		return strings.TrimSpace(key.ProjectID)
+	}
+	return ""
+}
+
+func (a *Account) VertexLocation(model string) string {
+	if a == nil {
+		return vertexDefaultLocation
+	}
+	if model != "" && a.Credentials != nil {
+		if raw, ok := a.Credentials["vertex_model_locations"].(map[string]any); ok {
+			if loc, ok := raw[model].(string); ok && strings.TrimSpace(loc) != "" {
+				return strings.TrimSpace(loc)
+			}
+		}
+	}
+	if v := strings.TrimSpace(a.GetCredential("location")); v != "" {
+		return v
+	}
+	if v := strings.TrimSpace(a.GetCredential("vertex_location")); v != "" {
+		return v
+	}
+	return vertexDefaultLocation
+}
+
+func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) {
+	if account == nil || account.Credentials == nil {
+		return nil, errors.New("service account credentials not configured")
+	}
+
+	if raw := strings.TrimSpace(account.GetCredential("service_account_json")); raw != "" {
+		return parseVertexServiceAccountJSON([]byte(raw))
+	}
+	if raw := strings.TrimSpace(account.GetCredential("service_account")); raw != "" {
+		return parseVertexServiceAccountJSON([]byte(raw))
+	}
+	if nested, ok := account.Credentials["service_account_json"].(map[string]any); ok {
+		b, _ := json.Marshal(nested)
+		return parseVertexServiceAccountJSON(b)
+	}
+	if nested, ok := account.Credentials["service_account"].(map[string]any); ok {
+		b, _ := json.Marshal(nested)
+		return parseVertexServiceAccountJSON(b)
+	}
+	return nil, errors.New("service_account_json not found in credentials")
+}
+
+func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) {
+	var key vertexServiceAccountKey
+	if err := json.Unmarshal(raw, &key); err != nil {
+		return nil, fmt.Errorf("invalid service account json: %w", err)
+	}
+	if strings.TrimSpace(key.ClientEmail) == "" {
+		return nil, errors.New("service account json missing client_email")
+	}
+	if strings.TrimSpace(key.PrivateKey) == "" {
+		return nil, errors.New("service account json missing private_key")
+	}
+	if strings.TrimSpace(key.ProjectID) == "" {
+		return nil, errors.New("service account json missing project_id")
+	}
+	// Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri.
+	key.TokenURI = vertexDefaultTokenURL
+	return &key, nil
+}
+
+func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string {
+	fingerprint := ""
+	if key != nil {
+		sum := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID))
+		fingerprint = hex.EncodeToString(sum[:8])
+	}
+	if fingerprint == "" && account != nil {
+		fingerprint = fmt.Sprintf("account:%d", account.ID)
+	}
+	return "vertex:service_account:" + fingerprint
+}
+
+// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
+// using the shared cache and distributed lock to avoid redundant exchanges.
+func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
+	key, err := parseVertexServiceAccountKey(account)
+	if err != nil {
+		return "", err
+	}
+	cacheKey := vertexServiceAccountCacheKey(account, key)
+
+	if cache != nil {
+		if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+			return token, nil
+		}
+	}
+
+	locked := false
+	if cache != nil {
+		var lockErr error
+		locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+		if lockErr == nil && locked {
+			defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
+		} else if lockErr != nil {
+			slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
+		} else {
+			time.Sleep(vertexLockWaitTime)
+			if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+				return token, nil
+			}
+		}
+	}
+
+	accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
+	if err != nil {
+		return "", err
+	}
+	if cache != nil {
+		_ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
+	}
+	return accessToken, nil
+}
+
+func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
+	now := time.Now()
+	claims := jwt.MapClaims{
+		"iss":   key.ClientEmail,
+		"scope": vertexCloudPlatformScope,
+		"aud":   key.TokenURI,
+		"iat":   now.Unix(),
+		"exp":   now.Add(time.Hour).Unix(),
+	}
+	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+	if strings.TrimSpace(key.PrivateKeyID) != "" {
+		token.Header["kid"] = key.PrivateKeyID
+	}
+	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey))
+	if err != nil {
+		return "", 0, fmt.Errorf("parse service account private key: %w", err)
+	}
+	assertion, err := token.SignedString(privateKey)
+	if err != nil {
+		return "", 0, fmt.Errorf("sign service account assertion: %w", err)
+	}
+
+	values := url.Values{}
+	values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
+	values.Set("assertion", assertion)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(values.Encode()))
+	if err != nil {
+		return "", 0, err
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+	client := &http.Client{Timeout: 15 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return "", 0, fmt.Errorf("service account token request failed: %w", err)
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+	var parsed vertexTokenResponse
+	_ = json.Unmarshal(body, &parsed)
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		msg := strings.TrimSpace(parsed.ErrorDesc)
+		if msg == "" {
+			msg = strings.TrimSpace(parsed.Error)
+		}
+		if msg == "" {
+			msg = string(bytes.TrimSpace(body))
+		}
+		return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg)
+	}
+	if strings.TrimSpace(parsed.AccessToken) == "" {
+		return "", 0, errors.New("service account token response missing access_token")
+	}
+	ttl := time.Duration(parsed.ExpiresIn) * time.Second
+	if ttl <= 0 {
+		ttl = time.Hour
+	}
+	if ttl > vertexServiceAccountCacheSkew {
+		ttl -= vertexServiceAccountCacheSkew
+	}
+	return parsed.AccessToken, ttl, nil
+}
+
+func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
+	projectID = strings.TrimSpace(projectID)
+	location = strings.TrimSpace(location)
+	model = strings.TrimSpace(model)
+	action = strings.TrimSpace(action)
+	if projectID == "" {
+		return "", errors.New("vertex project_id is required")
+	}
+	if location == "" {
+		location = vertexDefaultLocation
+	}
+	if !vertexLocationPattern.MatchString(location) {
+		return "", fmt.Errorf("invalid vertex location: %s", location)
+	}
+	if model == "" {
+		return "", errors.New("vertex model is required")
+	}
+	switch action {
+	case "generateContent", "streamGenerateContent", "countTokens":
+	default:
+		return "", fmt.Errorf("unsupported vertex gemini action: %s", action)
+	}
+	host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
+	if location == "global" {
+		host = "aiplatform.googleapis.com"
+	}
+	u := fmt.Sprintf(
+		"https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+		host,
+		url.PathEscape(projectID),
+		url.PathEscape(location),
+		url.PathEscape(model),
+		action,
+	)
+	if stream {
+		u += "?alt=sse"
+	}
+	return u, nil
+}
+
+func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) {
+	projectID = strings.TrimSpace(projectID)
+	location = strings.TrimSpace(location)
+	model = strings.TrimSpace(model)
+	if projectID == "" {
+		return "", errors.New("vertex project_id is required")
+	}
+	if location == "" {
+		location = vertexDefaultLocation
+	}
+	if !vertexLocationPattern.MatchString(location) {
+		return "", fmt.Errorf("invalid vertex location: %s", location)
+	}
+	if model == "" {
+		return "", errors.New("vertex model is required")
+	}
+	action := "rawPredict"
+	if stream {
+		action = "streamRawPredict"
+	}
+	host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
+	if location == "global" {
+		host = "aiplatform.googleapis.com"
+	}
+	escapedModel := strings.ReplaceAll(url.PathEscape(model), "%40", "@")
+	return fmt.Sprintf(
+		"https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
+		host,
+		url.PathEscape(projectID),
+		url.PathEscape(location),
+		escapedModel,
+		action,
+	), nil
+}
+
+func normalizeVertexAnthropicModelID(model string) string {
+	model = strings.TrimSpace(model)
+	if model == "" || vertexAnthropicAlreadyDatedIDPattern.MatchString(model) {
+		return model
+	}
+	if m := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(model); len(m) == 3 {
+		return m[1] + "@" + m[2]
+	}
+	return model
+}
+
+func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) {
+	var payload map[string]any
+	if err := json.Unmarshal(body, &payload); err != nil {
+		return nil, fmt.Errorf("parse anthropic vertex request body: %w", err)
+	}
+	delete(payload, "model")
+	payload["anthropic_version"] = vertexAnthropicVersion
+	return json.Marshal(payload)
+}
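The dated-model-ID handling in `normalizeVertexAnthropicModelID` can be exercised in isolation: Vertex addresses Anthropic models as `name@YYYYMMDD`, while Anthropic's own API uses `name-YYYYMMDD`. A minimal sketch using the same two regular expressions as the diff (the function name `toVertexModelID` is illustrative):

```go
package main

import (
	"fmt"
	"regexp"
)

// These two patterns mirror vertexAnthropicDatedModelIDPattern and
// vertexAnthropicAlreadyDatedIDPattern from the diff.
var (
	datedSuffix  = regexp.MustCompile(`^(.+)-([0-9]{8})$`)
	alreadyDated = regexp.MustCompile(`^.+@[0-9]{8}$`)
)

func toVertexModelID(model string) string {
	if model == "" || alreadyDated.MatchString(model) {
		return model // already in Vertex form (or empty): leave untouched
	}
	if m := datedSuffix.FindStringSubmatch(model); len(m) == 3 {
		return m[1] + "@" + m[2] // swap the final "-" before the date for "@"
	}
	return model // undated alias such as "claude-sonnet-4-6"
}

func main() {
	fmt.Println(toVertexModelID("claude-sonnet-4-5-20250929")) // claude-sonnet-4-5@20250929
}
```

Note the greedy `(.+)` keeps all earlier hyphens inside the model name, so only the final hyphen before an 8-digit date is rewritten.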
backend/internal/service/vertex_service_account_test.go (new file, 77 lines)
@@ -0,0 +1,77 @@
+package service
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
+)
+
+func TestBuildVertexGeminiURL(t *testing.T) {
+	got, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true)
+	require.NoError(t, err)
+	require.Equal(t, "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse", got)
+}
+
+func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) {
+	got, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true)
+	require.NoError(t, err)
+	require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse", got)
+}
+
+func TestBuildVertexAnthropicURL(t *testing.T) {
+	got, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false)
+	require.NoError(t, err)
+	require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", got)
+}
+
+func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) {
+	got, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true)
+	require.NoError(t, err)
+	require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict", got)
+}
+
+func TestNormalizeVertexAnthropicModelID(t *testing.T) {
+	require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5-20250929"))
+	require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5@20250929"))
+	require.Equal(t, "claude-sonnet-4-6", normalizeVertexAnthropicModelID("claude-sonnet-4-6"))
+}
+
+func TestBuildVertexAnthropicRequestBody(t *testing.T) {
+	got, err := buildVertexAnthropicRequestBody([]byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`))
+	require.NoError(t, err)
+	require.Equal(t, "", gjson.GetBytes(got, "model").String())
+	require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
+	require.Equal(t, int64(64), gjson.GetBytes(got, "max_tokens").Int())
+	require.Equal(t, "hi", gjson.GetBytes(got, "messages.0.content").String())
+}
+
+func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) {
+	_, err := buildVertexGeminiURL("my-project", "us-central1/path", "gemini-3-pro", "generateContent", false)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "invalid vertex location")
+}
+
+func TestParseVertexServiceAccountKey(t *testing.T) {
+	raw := `{
+		"type": "service_account",
+		"project_id": "vertex-proj",
+		"private_key_id": "kid",
+		"private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
+		"client_email": "svc@vertex-proj.iam.gserviceaccount.com"
+	}`
+	account := &Account{
+		Type:     AccountTypeServiceAccount,
+		Platform: PlatformGemini,
+		Credentials: map[string]any{
+			"service_account_json": raw,
+		},
+	}
+	key, err := parseVertexServiceAccountKey(account)
+	require.NoError(t, err)
+	require.Equal(t, "vertex-proj", key.ProjectID)
+	require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", key.ClientEmail)
+	require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
+	require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
+}
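The cache-key scheme in `vertexServiceAccountCacheKey` can also be sketched standalone: the fingerprint is the first 8 bytes of a SHA-256 over `client_email`, a NUL separator, and `private_key_id`, so rotating a key yields a fresh cache entry without embedding the email itself in the key. A minimal sketch (the function name `cacheKey` is illustrative):

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// cacheKey mirrors the fingerprint scheme from the diff: SHA-256 over
// client_email + NUL + private_key_id, truncated to the first 8 bytes.
func cacheKey(clientEmail, privateKeyID string) string {
	sum := sha256.Sum256([]byte(clientEmail + "\x00" + privateKeyID))
	return "vertex:service_account:" + hex.EncodeToString(sum[:8])
}

func main() {
	fmt.Println(cacheKey("svc@vertex-proj.iam.gserviceaccount.com", "kid"))
}
```

The NUL separator prevents ambiguity between the two concatenated fields, and the 8-byte truncation keeps Redis keys short while still distinguishing rotated credentials.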
@@ -404,12 +404,28 @@ func ProvideBillingCacheService(
 	return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
 }
+
+// ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation.
+func ProvideAPIKeyService(
+	apiKeyRepo APIKeyRepository,
+	userRepo UserRepository,
+	groupRepo GroupRepository,
+	userSubRepo UserSubscriptionRepository,
+	userGroupRateRepo UserGroupRateRepository,
+	cache APIKeyCache,
+	cfg *config.Config,
+	billingCacheService *BillingCacheService,
+) *APIKeyService {
+	svc := NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, userGroupRateRepo, cache, cfg)
+	svc.SetRateLimitCacheInvalidator(billingCacheService)
+	return svc
+}
+
 // ProviderSet is the Wire provider set for all services
 var ProviderSet = wire.NewSet(
 	// Core services
 	NewAuthService,
 	NewUserService,
-	NewAPIKeyService,
+	ProvideAPIKeyService,
 	ProvideAPIKeyAuthCacheInvalidator,
 	NewGroupService,
 	NewAccountService,
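The provider change above follows a common Wire pattern: when a collaborator cannot be passed through the constructor signature, a `Provide*` wrapper constructs the service and then attaches it via a setter. A minimal sketch with hypothetical types (not the project's actual `APIKeyService` or `BillingCacheService`):

```go
package main

import "fmt"

// invalidator and the types below are illustrative stand-ins; the real project
// wires *BillingCacheService into *APIKeyService the same way.
type invalidator interface{ InvalidateRateLimit(userID int64) }

type billingCache struct{}

func (b *billingCache) InvalidateRateLimit(userID int64) { fmt.Println("invalidate", userID) }

type apiKeyService struct{ inv invalidator }

func (s *apiKeyService) SetRateLimitCacheInvalidator(inv invalidator) { s.inv = inv }

func newAPIKeyService() *apiKeyService { return &apiKeyService{} }

// provideAPIKeyService mirrors ProvideAPIKeyService in the diff: construct, then wire.
func provideAPIKeyService(b *billingCache) *apiKeyService {
	svc := newAPIKeyService()
	svc.SetRateLimitCacheInvalidator(b)
	return svc
}

func main() {
	svc := provideAPIKeyService(&billingCache{})
	fmt.Println(svc.inv != nil) // true: the invalidator is attached at wiring time
}
```

Registering the wrapper instead of the raw constructor keeps the setter call inside the dependency graph, so no caller can obtain an `apiKeyService` without the invalidator attached.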
@@ -370,8 +370,8 @@ export async function batchUpdateCredentials(request: {
  * @returns Success confirmation
  */
 export async function bulkUpdate(
-  accountIds: number[],
-  updates: Record<string, unknown>
+  accountIdsOrPayload: number[] | Record<string, unknown>,
+  updates?: Record<string, unknown>
 ): Promise<{
   success: number
   failed: number

@@ -379,16 +379,19 @@ export async function bulkUpdate(
   failed_ids?: number[]
   results: Array<{ account_id: number; success: boolean; error?: string }>
 }> {
+  const payload = Array.isArray(accountIdsOrPayload)
+    ? {
+        account_ids: accountIdsOrPayload,
+        ...(updates ?? {})
+      }
+    : accountIdsOrPayload
   const { data } = await apiClient.post<{
     success: number
     failed: number
     success_ids?: number[]
     failed_ids?: number[]
     results: Array<{ account_id: number; success: boolean; error?: string }>
-  }>('/admin/accounts/bulk-update', {
-    account_ids: accountIds,
-    ...updates
-  })
+  }>('/admin/accounts/bulk-update', payload)
   return data
 }
|||||||
@@ -484,6 +484,9 @@ export interface SystemSettings {
|
|||||||
|
|
||||||
// Affiliate (邀请返利) feature switch
|
// Affiliate (邀请返利) feature switch
|
||||||
affiliate_enabled: boolean;
|
affiliate_enabled: boolean;
|
||||||
|
|
||||||
|
// OpenAI fast/flex policy
|
||||||
|
openai_fast_policy_settings?: OpenAIFastPolicySettings;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UpdateSettingsRequest {
|
export interface UpdateSettingsRequest {
|
||||||
@@ -648,6 +651,9 @@ export interface UpdateSettingsRequest {
|
|||||||
|
|
||||||
// Affiliate (邀请返利) feature switch
|
// Affiliate (邀请返利) feature switch
|
||||||
affiliate_enabled?: boolean;
|
affiliate_enabled?: boolean;
|
||||||
|
|
||||||
|
// OpenAI fast/flex policy
|
||||||
|
openai_fast_policy_settings?: OpenAIFastPolicySettings;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -875,6 +881,29 @@ export async function updateRectifierSettings(
|
|||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==================== OpenAI Fast Policy Settings ====================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OpenAI fast/flex policy rule interface.
|
||||||
|
* Matches backend dto.OpenAIFastPolicyRule.
|
||||||
|
*/
|
||||||
|
export interface OpenAIFastPolicyRule {
|
||||||
|
service_tier: "all" | "priority" | "flex";
|
||||||
|
action: "pass" | "filter" | "block";
|
||||||
|
scope: "all" | "oauth" | "apikey" | "bedrock";
|
||||||
|
error_message?: string;
|
||||||
|
model_whitelist?: string[];
|
||||||
|
fallback_action?: "pass" | "filter" | "block";
|
||||||
|
fallback_error_message?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OpenAI fast/flex policy settings interface.
|
||||||
|
*/
|
||||||
|
export interface OpenAIFastPolicySettings {
|
||||||
|
rules: OpenAIFastPolicyRule[];
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== Beta Policy Settings ====================
|
// ==================== Beta Policy Settings ====================
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -332,6 +332,37 @@
 
 <!-- Usage data or unlimited flow -->
 <div class="space-y-1">
+  <div
+    v-if="showGeminiTodayStats && todayStats"
+    class="mb-0.5 flex items-center"
+  >
+    <div class="flex items-center gap-1.5 text-[9px] text-gray-500 dark:text-gray-400">
+      <span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
+        {{ formatKeyRequests }} req
+      </span>
+      <span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
+        {{ formatKeyTokens }}
+      </span>
+      <span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800" :title="t('usage.accountBilled')">
+        A ${{ formatKeyCost }}
+      </span>
+      <span
+        v-if="todayStats.user_cost != null"
+        class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800"
+        :title="t('usage.userBilled')"
+      >
+        U ${{ formatKeyUserCost }}
+      </span>
+    </div>
+  </div>
+  <div
+    v-else-if="showGeminiTodayStats && todayStatsLoading"
+    class="mb-0.5 flex items-center gap-1"
+  >
+    <div class="h-3 w-10 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
+    <div class="h-3 w-8 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
+    <div class="h-3 w-12 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
+  </div>
 <div v-if="loading" class="space-y-1">
   <div class="flex items-center gap-1">
     <div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>

@@ -512,6 +543,10 @@ const shouldFetchUsage = computed(() => {
   return false
 })
+
+const showGeminiTodayStats = computed(() => {
+  return props.account.platform === 'gemini' && props.account.type === 'service_account'
+})
+
 const geminiUsageAvailable = computed(() => {
   return (
     !!usageInfo.value?.gemini_shared_daily ||
@@ -17,7 +17,7 @@
       d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
     />
   </svg>
-  {{ t('admin.accounts.bulkEdit.selectionInfo', { count: accountIds.length }) }}
+  {{ t('admin.accounts.bulkEdit.selectionInfo', { count: targetMode === 'filtered' ? targetPreviewCount : accountIds.length }) }}
 </p>
 </div>

@@ -27,7 +27,7 @@
 <svg class="mr-1.5 inline h-5 w-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
   <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
 </svg>
-{{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
+{{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: targetSelectedPlatforms.join(', ') }) }}
 </p>
 </div>

@@ -227,7 +227,7 @@
 
 <ModelWhitelistSelector
   v-model="allowedModels"
-  :platforms="selectedPlatforms"
+  :platforms="targetSelectedPlatforms"
 />
 
 <p class="text-xs text-gray-500 dark:text-gray-400">

@@ -698,6 +698,87 @@
   </div>
 </div>
+
+<!-- OpenAI OAuth Codex CLI only -->
+<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
+  <div class="mb-3 flex items-center justify-between">
+    <label
+      id="bulk-edit-openai-codex-cli-only-label"
+      class="input-label mb-0"
+      for="bulk-edit-openai-codex-cli-only-enabled"
+    >
+      {{ t('admin.accounts.openai.codexCLIOnly') }}
+    </label>
+    <input
+      v-model="enableCodexCLIOnly"
+      id="bulk-edit-openai-codex-cli-only-enabled"
+      type="checkbox"
+      aria-controls="bulk-edit-openai-codex-cli-only"
+      class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
+    />
+  </div>
+  <div
+    id="bulk-edit-openai-codex-cli-only"
+    :class="!enableCodexCLIOnly && 'pointer-events-none opacity-50'"
+  >
+    <p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
+      {{ t('admin.accounts.openai.codexCLIOnlyDesc') }}
+    </p>
+    <button
+      id="bulk-edit-openai-codex-cli-only-toggle"
+      type="button"
+      :class="[
+        'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
+        codexCLIOnlyEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
+      ]"
+      @click="codexCLIOnlyEnabled = !codexCLIOnlyEnabled"
+    >
+      <span
+        :class="[
+          'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
+          codexCLIOnlyEnabled ? 'translate-x-5' : 'translate-x-0'
+        ]"
+      />
+    </button>
+  </div>
+</div>
+
+<!-- OpenAI API Key WS mode -->
+<div v-if="allOpenAIAPIKey" class="border-t border-gray-200 pt-4 dark:border-dark-600">
+  <div class="mb-3 flex items-center justify-between">
+    <label
+      id="bulk-edit-openai-apikey-ws-mode-label"
+      class="input-label mb-0"
+      for="bulk-edit-openai-apikey-ws-mode-enabled"
+    >
+      {{ t('admin.accounts.openai.wsMode') }}
+    </label>
+    <input
+      v-model="enableOpenAIAPIKeyWSMode"
+      id="bulk-edit-openai-apikey-ws-mode-enabled"
+      type="checkbox"
+      aria-controls="bulk-edit-openai-apikey-ws-mode"
+      class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
+    />
+  </div>
+  <div
+    id="bulk-edit-openai-apikey-ws-mode"
+    :class="!enableOpenAIAPIKeyWSMode && 'pointer-events-none opacity-50'"
+  >
+    <p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
+      {{ t('admin.accounts.openai.wsModeDesc') }}
+    </p>
+    <p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t(openAIAPIKeyWSModeConcurrencyHintKey) }}
|
||||||
|
</p>
|
||||||
|
<Select
|
||||||
|
v-model="openaiAPIKeyResponsesWebSocketV2Mode"
|
||||||
|
data-testid="bulk-edit-openai-apikey-ws-mode-select"
|
||||||
|
:options="openAIWSModeOptions"
|
||||||
|
aria-labelledby="bulk-edit-openai-apikey-ws-mode-label"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
|
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
|
||||||
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
<div class="mb-3 flex items-center justify-between">
|
<div class="mb-3 flex items-center justify-between">
|
||||||
@@ -933,6 +1014,13 @@ interface Props {
|
|||||||
accountIds: number[]
|
accountIds: number[]
|
||||||
selectedPlatforms: AccountPlatform[]
|
selectedPlatforms: AccountPlatform[]
|
||||||
selectedTypes: AccountType[]
|
selectedTypes: AccountType[]
|
||||||
|
target?: {
|
||||||
|
mode: 'selected' | 'filtered'
|
||||||
|
filters?: Record<string, unknown>
|
||||||
|
previewCount?: number
|
||||||
|
selectedPlatforms?: AccountPlatform[]
|
||||||
|
selectedTypes?: AccountType[]
|
||||||
|
}
|
||||||
proxies: ProxyConfig[]
|
proxies: ProxyConfig[]
|
||||||
groups: AdminGroup[]
|
groups: AdminGroup[]
|
||||||
}
|
}
|
||||||
@@ -947,40 +1035,53 @@ const { t } = useI18n()
|
|||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
|
|
||||||
// Platform awareness
|
// Platform awareness
|
||||||
const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1)
|
const targetMode = computed(() => props.target?.mode ?? 'selected')
|
||||||
|
const targetPreviewCount = computed(() => props.target?.previewCount ?? props.accountIds.length)
|
||||||
|
const targetSelectedPlatforms = computed(() => props.target?.selectedPlatforms ?? props.selectedPlatforms)
|
||||||
|
const targetSelectedTypes = computed(() => props.target?.selectedTypes ?? props.selectedTypes)
|
||||||
|
const isMixedPlatform = computed(() => targetSelectedPlatforms.value.length > 1)
|
||||||
|
|
||||||
const allOpenAIPassthroughCapable = computed(() => {
|
const allOpenAIPassthroughCapable = computed(() => {
|
||||||
return (
|
return (
|
||||||
props.selectedPlatforms.length === 1 &&
|
targetSelectedPlatforms.value.length === 1 &&
|
||||||
props.selectedPlatforms[0] === 'openai' &&
|
targetSelectedPlatforms.value[0] === 'openai' &&
|
||||||
props.selectedTypes.length > 0 &&
|
targetSelectedTypes.value.length > 0 &&
|
||||||
props.selectedTypes.every(t => t === 'oauth' || t === 'apikey')
|
targetSelectedTypes.value.every(t => t === 'oauth' || t === 'apikey')
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
const allOpenAIOAuth = computed(() => {
|
const allOpenAIOAuth = computed(() => {
|
||||||
return (
|
return (
|
||||||
props.selectedPlatforms.length === 1 &&
|
targetSelectedPlatforms.value.length === 1 &&
|
||||||
props.selectedPlatforms[0] === 'openai' &&
|
targetSelectedPlatforms.value[0] === 'openai' &&
|
||||||
props.selectedTypes.length > 0 &&
|
targetSelectedTypes.value.length > 0 &&
|
||||||
props.selectedTypes.every(t => t === 'oauth')
|
targetSelectedTypes.value.every(t => t === 'oauth')
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
const allOpenAIAPIKey = computed(() => {
|
||||||
|
return (
|
||||||
|
targetSelectedPlatforms.value.length === 1 &&
|
||||||
|
targetSelectedPlatforms.value[0] === 'openai' &&
|
||||||
|
targetSelectedTypes.value.length > 0 &&
|
||||||
|
targetSelectedTypes.value.every(t => t === 'apikey')
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示)
|
// 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示)
|
||||||
const allAnthropicOAuthOrSetupToken = computed(() => {
|
const allAnthropicOAuthOrSetupToken = computed(() => {
|
||||||
return (
|
return (
|
||||||
props.selectedPlatforms.length === 1 &&
|
targetSelectedPlatforms.value.length === 1 &&
|
||||||
props.selectedPlatforms[0] === 'anthropic' &&
|
targetSelectedPlatforms.value[0] === 'anthropic' &&
|
||||||
props.selectedTypes.every(t => t === 'oauth' || t === 'setup-token')
|
targetSelectedTypes.value.every(t => t === 'oauth' || t === 'setup-token')
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
const filteredPresets = computed(() => {
|
const filteredPresets = computed(() => {
|
||||||
if (props.selectedPlatforms.length === 0) return []
|
if (targetSelectedPlatforms.value.length === 0) return []
|
||||||
|
|
||||||
const dedupedPresets = new Map<string, ReturnType<typeof getPresetMappingsByPlatform>[number]>()
|
const dedupedPresets = new Map<string, ReturnType<typeof getPresetMappingsByPlatform>[number]>()
|
||||||
for (const platform of props.selectedPlatforms) {
|
for (const platform of targetSelectedPlatforms.value) {
|
||||||
for (const preset of getPresetMappingsByPlatform(platform)) {
|
for (const preset of getPresetMappingsByPlatform(platform)) {
|
||||||
const key = `${preset.from}=>${preset.to}`
|
const key = `${preset.from}=>${preset.to}`
|
||||||
if (!dedupedPresets.has(key)) {
|
if (!dedupedPresets.has(key)) {
|
||||||
@@ -1012,6 +1113,8 @@ const enableStatus = ref(false)
|
|||||||
const enableGroups = ref(false)
|
const enableGroups = ref(false)
|
||||||
const enableOpenAIPassthrough = ref(false)
|
const enableOpenAIPassthrough = ref(false)
|
||||||
const enableOpenAIWSMode = ref(false)
|
const enableOpenAIWSMode = ref(false)
|
||||||
|
const enableOpenAIAPIKeyWSMode = ref(false)
|
||||||
|
const enableCodexCLIOnly = ref(false)
|
||||||
const enableRpmLimit = ref(false)
|
const enableRpmLimit = ref(false)
|
||||||
|
|
||||||
// State - field values
|
// State - field values
|
||||||
@@ -1035,6 +1138,8 @@ const status = ref<'active' | 'inactive'>('active')
|
|||||||
const groupIds = ref<number[]>([])
|
const groupIds = ref<number[]>([])
|
||||||
const openaiPassthroughEnabled = ref(false)
|
const openaiPassthroughEnabled = ref(false)
|
||||||
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
|
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
|
const codexCLIOnlyEnabled = ref(false)
|
||||||
const rpmLimitEnabled = ref(false)
|
const rpmLimitEnabled = ref(false)
|
||||||
const bulkBaseRpm = ref<number | null>(null)
|
const bulkBaseRpm = ref<number | null>(null)
|
||||||
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
|
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
|
||||||
@@ -1076,6 +1181,9 @@ const openAIWSModeOptions = computed(() => [
|
|||||||
const openAIWSModeConcurrencyHintKey = computed(() =>
|
const openAIWSModeConcurrencyHintKey = computed(() =>
|
||||||
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
|
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||||
)
|
)
|
||||||
|
const openAIAPIKeyWSModeConcurrencyHintKey = computed(() =>
|
||||||
|
resolveOpenAIWSModeConcurrencyHintKey(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
||||||
|
)
|
||||||
|
|
||||||
// Model mapping helpers
|
// Model mapping helpers
|
||||||
const addModelMapping = () => {
|
const addModelMapping = () => {
|
||||||
@@ -1254,6 +1362,19 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (enableOpenAIAPIKeyWSMode.value) {
|
||||||
|
const extra = ensureExtra()
|
||||||
|
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||||
|
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
|
||||||
|
openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (enableCodexCLIOnly.value) {
|
||||||
|
const extra = ensureExtra()
|
||||||
|
extra.codex_cli_only = codexCLIOnlyEnabled.value
|
||||||
|
}
|
||||||
|
|
||||||
// RPM limit settings (写入 extra 字段)
|
// RPM limit settings (写入 extra 字段)
|
||||||
if (enableRpmLimit.value) {
|
if (enableRpmLimit.value) {
|
||||||
const extra = ensureExtra()
|
const extra = ensureExtra()
|
||||||
@@ -1291,8 +1412,8 @@ const mixedChannelConfirmed = ref(false)
|
|||||||
const canPreCheck = () =>
|
const canPreCheck = () =>
|
||||||
enableGroups.value &&
|
enableGroups.value &&
|
||||||
groupIds.value.length > 0 &&
|
groupIds.value.length > 0 &&
|
||||||
props.selectedPlatforms.length === 1 &&
|
targetSelectedPlatforms.value.length === 1 &&
|
||||||
(props.selectedPlatforms[0] === 'antigravity' || props.selectedPlatforms[0] === 'anthropic')
|
(targetSelectedPlatforms.value[0] === 'antigravity' || targetSelectedPlatforms.value[0] === 'anthropic')
|
||||||
|
|
||||||
const handleClose = () => {
|
const handleClose = () => {
|
||||||
showMixedChannelWarning.value = false
|
showMixedChannelWarning.value = false
|
||||||
@@ -1309,7 +1430,7 @@ const preCheckMixedChannelRisk = async (built: Record<string, unknown>): Promise
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await adminAPI.accounts.checkMixedChannelRisk({
|
const result = await adminAPI.accounts.checkMixedChannelRisk({
|
||||||
platform: props.selectedPlatforms[0],
|
platform: targetSelectedPlatforms.value[0],
|
||||||
group_ids: groupIds.value
|
group_ids: groupIds.value
|
||||||
})
|
})
|
||||||
if (!result.has_risk) return true
|
if (!result.has_risk) return true
|
||||||
@@ -1325,7 +1446,7 @@ const preCheckMixedChannelRisk = async (built: Record<string, unknown>): Promise
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handleSubmit = async () => {
|
const handleSubmit = async () => {
|
||||||
if (props.accountIds.length === 0) {
|
if (targetMode.value === 'selected' && props.accountIds.length === 0) {
|
||||||
appStore.showError(t('admin.accounts.bulkEdit.noSelection'))
|
appStore.showError(t('admin.accounts.bulkEdit.noSelection'))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1344,6 +1465,8 @@ const handleSubmit = async () => {
|
|||||||
enableStatus.value ||
|
enableStatus.value ||
|
||||||
enableGroups.value ||
|
enableGroups.value ||
|
||||||
enableOpenAIWSMode.value ||
|
enableOpenAIWSMode.value ||
|
||||||
|
enableOpenAIAPIKeyWSMode.value ||
|
||||||
|
enableCodexCLIOnly.value ||
|
||||||
enableRpmLimit.value ||
|
enableRpmLimit.value ||
|
||||||
userMsgQueueMode.value !== null
|
userMsgQueueMode.value !== null
|
||||||
|
|
||||||
@@ -1373,7 +1496,12 @@ const submitBulkUpdate = async (baseUpdates: Record<string, unknown>) => {
|
|||||||
submitting.value = true
|
submitting.value = true
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const res = await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
|
const res = targetMode.value === 'filtered' && props.target?.filters
|
||||||
|
? await adminAPI.accounts.bulkUpdate({
|
||||||
|
filters: props.target.filters,
|
||||||
|
...updates
|
||||||
|
})
|
||||||
|
: await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
|
||||||
const success = res.success || 0
|
const success = res.success || 0
|
||||||
const failed = res.failed || 0
|
const failed = res.failed || 0
|
||||||
|
|
||||||
@@ -1437,6 +1565,8 @@ watch(
|
|||||||
enableGroups.value = false
|
enableGroups.value = false
|
||||||
enableOpenAIPassthrough.value = false
|
enableOpenAIPassthrough.value = false
|
||||||
enableOpenAIWSMode.value = false
|
enableOpenAIWSMode.value = false
|
||||||
|
enableOpenAIAPIKeyWSMode.value = false
|
||||||
|
enableCodexCLIOnly.value = false
|
||||||
enableRpmLimit.value = false
|
enableRpmLimit.value = false
|
||||||
|
|
||||||
// Reset all values
|
// Reset all values
|
||||||
@@ -1456,6 +1586,8 @@ watch(
|
|||||||
status.value = 'active'
|
status.value = 'active'
|
||||||
groupIds.value = []
|
groupIds.value = []
|
||||||
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
|
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
|
codexCLIOnlyEnabled.value = false
|
||||||
rpmLimitEnabled.value = false
|
rpmLimitEnabled.value = false
|
||||||
bulkBaseRpm.value = null
|
bulkBaseRpm.value = null
|
||||||
bulkRpmStrategy.value = 'tiered'
|
bulkRpmStrategy.value = 'tiered'
|
||||||
|
|||||||
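The `filtered`-vs-`selected` dispatch that `submitBulkUpdate` gains above can be sketched outside the component. This is a minimal sketch under assumptions: the `BulkTarget` interface and `buildBulkUpdateRequest` helper are illustrative names, and the `ids` key in selected mode is an assumption (the real `adminAPI.accounts.bulkUpdate` takes the ID list as a separate argument).

```typescript
// Hypothetical standalone sketch of how the bulk-update request body is chosen:
// 'filtered' mode sends the saved filters so the server resolves matching
// accounts; 'selected' mode operates on the explicitly chosen account IDs.
interface BulkTarget {
  mode: 'selected' | 'filtered'
  filters?: Record<string, unknown>
}

function buildBulkUpdateRequest(
  accountIds: number[],
  updates: Record<string, unknown>,
  target?: BulkTarget
): Record<string, unknown> {
  if (target?.mode === 'filtered' && target.filters) {
    // Filtered mode: carry the filters alongside the field updates.
    return { filters: target.filters, ...updates }
  }
  // Selected mode: carry the explicit ID list (assumed key name).
  return { ids: accountIds, ...updates }
}
```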
@@ -153,7 +153,7 @@
 <!-- Account Type Selection (Anthropic) -->
 <div v-if="form.platform === 'anthropic'">
   <label class="input-label">{{ t('admin.accounts.accountType') }}</label>
-  <div class="mt-2 grid grid-cols-3 gap-3" data-tour="account-form-type">
+  <div class="mt-2 grid grid-cols-2 gap-3 sm:grid-cols-4" data-tour="account-form-type">
     <button
       type="button"
       @click="accountCategory = 'oauth-based'"
@@ -244,6 +244,39 @@
     </div>
     </button>

+    <button
+      type="button"
+      @click="accountCategory = 'service_account'"
+      :class="[
+        'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
+        accountCategory === 'service_account'
+          ? 'border-sky-500 bg-sky-50 dark:bg-sky-900/20'
+          : 'border-gray-200 hover:border-sky-300 dark:border-dark-600 dark:hover:border-sky-700'
+      ]"
+    >
+      <div
+        :class="[
+          'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
+          accountCategory === 'service_account'
+            ? 'bg-sky-500 text-white'
+            : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
+        ]"
+      >
+        <Icon name="cloud" size="sm" />
+      </div>
+      <div>
+        <span class="block text-sm font-medium text-gray-900 dark:text-white">Vertex</span>
+        <span class="text-xs text-gray-500 dark:text-gray-400">Service Account</span>
+      </div>
+    </button>
+
+  </div>
+
+  <div
+    v-if="accountCategory === 'service_account'"
+    class="mt-3 rounded-lg border border-sky-200 bg-sky-50 px-3 py-2 text-xs text-sky-800 dark:border-sky-800/40 dark:bg-sky-900/20 dark:text-sky-200"
+  >
+    <p>{{ t('admin.accounts.vertexAnthropicHint') }}</p>
   </div>
 </div>

@@ -302,6 +335,7 @@
       <span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.responsesApi') }}</span>
     </div>
     </button>
+
   </div>
 </div>

@@ -320,7 +354,7 @@
     {{ t('admin.accounts.gemini.helpButton') }}
   </button>
 </div>
-<div class="mt-2 grid grid-cols-2 gap-3" data-tour="account-form-type">
+<div class="mt-2 grid grid-cols-3 gap-3" data-tour="account-form-type">
   <button
     type="button"
     @click="accountCategory = 'oauth-based'"
@@ -392,6 +426,36 @@
     </span>
     </div>
     </button>

+    <button
+      type="button"
+      @click="accountCategory = 'service_account'"
+      :class="[
+        'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
+        accountCategory === 'service_account'
+          ? 'border-sky-500 bg-sky-50 dark:bg-sky-900/20'
+          : 'border-gray-200 hover:border-sky-300 dark:border-dark-600 dark:hover:border-sky-700'
+      ]"
+    >
+      <div
+        :class="[
+          'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
+          accountCategory === 'service_account'
+            ? 'bg-sky-500 text-white'
+            : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
+        ]"
+      >
+        <Icon name="cloud" size="sm" />
+      </div>
+      <div>
+        <span class="block text-sm font-medium text-gray-900 dark:text-white">
+          Vertex
+        </span>
+        <span class="text-xs text-gray-500 dark:text-gray-400">
+          Service Account
+        </span>
+      </div>
+    </button>
 </div>

 <div
@@ -411,6 +475,13 @@
     </div>
   </div>
+
+  <div
+    v-if="accountCategory === 'service_account'"
+    class="mt-3 rounded-lg border border-sky-200 bg-sky-50 px-3 py-2 text-xs text-sky-800 dark:border-sky-800/40 dark:bg-sky-900/20 dark:text-sky-200"
+  >
+    <p>{{ t('admin.accounts.vertexGeminiHint') }}</p>
+  </div>

 <!-- OAuth Type Selection (only show when oauth-based is selected) -->
 <div v-if="accountCategory === 'oauth-based'" class="mt-4">
 <label class="input-label">{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}</label>
@@ -610,7 +681,7 @@
 </div>

 <!-- Tier selection (used as fallback when auto-detection is unavailable/fails) -->
-<div class="mt-4">
+<div v-if="accountCategory !== 'service_account'" class="mt-4">
   <label class="input-label">{{ t('admin.accounts.gemini.tier.label') }}</label>
   <div class="mt-2">
     <select
@@ -729,6 +800,96 @@
   </div>
 </div>

+<!-- Vertex Service Account -->
+<div v-if="(form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory === 'service_account'" class="space-y-4">
+  <div>
+    <label class="input-label">Service Account JSON</label>
+    <input
+      ref="vertexServiceAccountFileInput"
+      type="file"
+      accept="application/json,.json"
+      class="hidden"
+      @change="handleVertexServiceAccountFile"
+    />
+    <div
+      :class="[
+        'rounded-lg border-2 border-dashed px-4 py-5 transition-colors',
+        vertexServiceAccountDragActive
+          ? 'border-sky-500 bg-sky-50 dark:border-sky-500 dark:bg-sky-900/20'
+          : 'border-gray-300 bg-gray-50 hover:border-sky-400 hover:bg-sky-50/60 dark:border-dark-500 dark:bg-dark-700/40 dark:hover:border-sky-600 dark:hover:bg-sky-900/10'
+      ]"
+      @dragenter.prevent="vertexServiceAccountDragActive = true"
+      @dragover.prevent="vertexServiceAccountDragActive = true"
+      @dragleave.prevent="vertexServiceAccountDragActive = false"
+      @drop.prevent="handleVertexServiceAccountDrop"
+    >
+      <div class="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between">
+        <div class="min-w-0">
+          <div class="flex items-center gap-2 text-sm font-medium text-gray-900 dark:text-white">
+            <Icon name="upload" size="sm" />
+            <span>{{ vertexClientEmail ? t('admin.accounts.vertexSaJsonLoaded') : t('admin.accounts.vertexSaJsonDrop') }}</span>
+          </div>
+          <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
+            {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonKeyHidden') : t('admin.accounts.vertexSaJsonDropHint') }}
+          </p>
+        </div>
+        <button
+          type="button"
+          class="btn btn-secondary shrink-0"
+          @click="vertexServiceAccountFileInput?.click()"
+        >
+          <Icon name="upload" size="sm" />
+          {{ t('admin.accounts.vertexSaJsonSelectBtn') }}
+        </button>
+      </div>
+      <div
+        v-if="vertexClientEmail"
+        class="mt-3 rounded-md border border-sky-200 bg-white px-3 py-2 text-xs text-sky-900 dark:border-sky-800/50 dark:bg-dark-800 dark:text-sky-200"
+      >
+        <div class="truncate">Project ID: <span class="font-mono">{{ vertexProjectId }}</span></div>
+        <div class="truncate">Client Email: <span class="font-mono">{{ vertexClientEmail }}</span></div>
+      </div>
+    </div>
+    <p class="input-hint">{{ t('admin.accounts.vertexSaJsonUploadHint') }}</p>
+  </div>
+
+  <div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
+    <div>
+      <label class="input-label">Project ID</label>
+      <input
+        v-model="vertexProjectId"
+        type="text"
+        class="input font-mono"
+        readonly
+        :placeholder="t('admin.accounts.vertexProjectIdPlaceholder')"
+      />
+    </div>
+    <div>
+      <label class="input-label">Location</label>
+      <select
+        v-model="vertexLocation"
+        required
+        class="input font-mono"
+      >
+        <optgroup
+          v-for="group in VERTEX_LOCATION_OPTIONS"
+          :key="group.label"
+          :label="group.label"
+        >
+          <option
+            v-for="option in group.options"
+            :key="option.value"
+            :value="option.value"
+          >
+            {{ option.label }}
+          </option>
+        </optgroup>
+      </select>
+      <p class="input-hint">{{ t('admin.accounts.vertexLocationHint') }}</p>
+    </div>
+  </div>
+</div>
+
 <!-- Antigravity model restriction (applies to OAuth + Upstream) -->
 <!-- Antigravity only supports model-mapping mode, not whitelist mode -->
 <div v-if="form.platform === 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
@@ -2971,6 +3132,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
 import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
 import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
 import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
+import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
 import {
   OPENAI_WS_MODE_CTX_POOL,
   OPENAI_WS_MODE_OFF,
@@ -3085,7 +3247,7 @@ interface TempUnschedRuleForm {
 // State
 const step = ref(1)
 const submitting = ref(false)
-const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category
+const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category
 const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
 const apiKeyBaseUrl = ref('https://api.anthropic.com')
 const apiKeyValue = ref('')
@@ -3151,6 +3313,12 @@ const bedrockSessionToken = ref('')
 const bedrockRegion = ref('us-east-1')
 const bedrockForceGlobal = ref(false)
 const bedrockApiKeyValue = ref('')
+const vertexServiceAccountFileInput = ref<HTMLInputElement | null>(null)
+const vertexServiceAccountJson = ref('')
+const vertexProjectId = ref('')
+const vertexClientEmail = ref('')
+const vertexLocation = ref('global')
+const vertexServiceAccountDragActive = ref(false)
 const tempUnschedEnabled = ref(false)
 const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
 const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
@@ -3397,7 +3565,7 @@ watch(

 // Sync form.type based on accountCategory, addMethod, and platform-specific type
 watch(
-  [accountCategory, addMethod, antigravityAccountType],
+  [accountCategory, addMethod, antigravityAccountType, () => form.platform],
   ([category, method, agType]) => {
     // Antigravity upstream type (actually created as apikey)
     if (form.platform === 'antigravity' && agType === 'upstream') {
@@ -3409,7 +3577,9 @@ watch(
       form.type = 'bedrock' as AccountType
       return
     }
-    if (category === 'oauth-based') {
+    if ((form.platform === 'gemini' || form.platform === 'anthropic') && category === 'service_account') {
+      form.type = 'service_account' as AccountType
+    } else if (category === 'oauth-based') {
       form.type = method as AccountType // 'oauth' or 'setup-token'
     } else {
       form.type = 'apikey'
@@ -3447,6 +3617,12 @@ watch(
     antigravityModelMappings.value = []
     antigravityModelRestrictionMode.value = 'mapping'
   }
+  if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') {
+    accountCategory.value = 'oauth-based'
+  }
+  if (newPlatform !== 'anthropic' && accountCategory.value === 'bedrock') {
+    accountCategory.value = 'oauth-based'
+  }
   // Reset Bedrock fields when switching platforms
   bedrockAccessKeyId.value = ''
   bedrockSecretAccessKey.value = ''
@@ -3455,6 +3631,10 @@ watch(
   bedrockForceGlobal.value = false
   bedrockAuthMode.value = 'sigv4'
   bedrockApiKeyValue.value = ''
+  vertexServiceAccountJson.value = ''
+  vertexProjectId.value = ''
+  vertexClientEmail.value = ''
+  vertexLocation.value = 'global'
   // Reset Anthropic/Antigravity-specific settings when switching to other platforms
   if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
     interceptWarmupRequests.value = false
@@ -3886,6 +4066,10 @@ const resetForm = () => {
   antigravityAccountType.value = 'oauth'
   upstreamBaseUrl.value = ''
   upstreamApiKey.value = ''
+  vertexServiceAccountJson.value = ''
+  vertexProjectId.value = ''
+  vertexClientEmail.value = ''
+  vertexLocation.value = 'global'
   tempUnschedEnabled.value = false
   tempUnschedRules.value = []
   geminiOAuthType.value = 'code_assist'
@@ -4009,6 +4193,52 @@ const normalizePoolModeRetryCount = (value: number) => {
   return normalized
 }

+const applyVertexServiceAccountJson = (value: string) => {
+  const raw = value.trim()
+  if (!raw) {
+    vertexProjectId.value = ''
+    vertexClientEmail.value = ''
+    return false
+  }
+  try {
+    const parsed = JSON.parse(raw) as Record<string, unknown>
+    const projectId = typeof parsed.project_id === 'string' ? parsed.project_id.trim() : ''
+    const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : ''
+    const privateKey = typeof parsed.private_key === 'string' ? parsed.private_key.trim() : ''
+    if (!projectId || !clientEmail || !privateKey) {
+      appStore.showError(t('admin.accounts.vertexSaJsonMissingFields'))
+      return false
+    }
+    vertexProjectId.value = projectId
+    vertexClientEmail.value = clientEmail
+    vertexServiceAccountJson.value = JSON.stringify(parsed)
+    return true
+  } catch {
+    appStore.showError(t('admin.accounts.vertexSaJsonInvalid'))
+    return false
+  }
+}
+
+const parseVertexServiceAccountJson = () => applyVertexServiceAccountJson(vertexServiceAccountJson.value)
+
+const handleVertexServiceAccountFile = async (event: Event) => {
+  const input = event.target as HTMLInputElement
+  const file = input.files?.[0]
+  if (!file) return
+  try {
+    applyVertexServiceAccountJson(await file.text())
+  } finally {
+    input.value = ''
+  }
+}
+
+const handleVertexServiceAccountDrop = async (event: DragEvent) => {
+  vertexServiceAccountDragActive.value = false
+  const file = event.dataTransfer?.files?.[0]
+  if (!file) return
+  applyVertexServiceAccountJson(await file.text())
+}
+
 const handleSubmit = async () => {
   // For OAuth-based type, handle OAuth flow (goes to step 2)
   if (isOAuthFlow.value) {
@@ -4122,6 +4352,29 @@ const handleSubmit = async () => {
     return
   }

+  if ((form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory.value === 'service_account') {
+    if (!form.name.trim()) {
+      appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
+      return
+    }
+    if (!parseVertexServiceAccountJson()) {
+      return
+    }
|
if (!vertexLocation.value.trim()) {
|
||||||
|
appStore.showError(t('admin.accounts.vertexLocationRequired'))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const credentials: Record<string, unknown> = {
|
||||||
|
service_account_json: vertexServiceAccountJson.value.trim(),
|
||||||
|
project_id: vertexProjectId.value.trim(),
|
||||||
|
client_email: vertexClientEmail.value.trim(),
|
||||||
|
location: vertexLocation.value.trim(),
|
||||||
|
tier_id: 'vertex'
|
||||||
|
}
|
||||||
|
await createAccountAndFinish(form.platform, 'service_account' as AccountType, credentials)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// For apikey type, create directly
|
// For apikey type, create directly
|
||||||
if (!apiKeyValue.value.trim()) {
|
if (!apiKeyValue.value.trim()) {
|
||||||
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
|
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
|
||||||
|
|||||||
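The `applyVertexServiceAccountJson` helper added above does the parse-and-validate work for an uploaded Service Account key. A standalone sketch of that same step (field names `project_id`, `client_email`, `private_key` follow the diff; the function name and return shape here are illustrative, not part of the change):

```typescript
// Minimal sketch of the validation performed by applyVertexServiceAccountJson:
// parse the JSON, require the three fields the component checks, and return
// the trimmed values, or null where the component would show an error toast.
interface VertexServiceAccount {
  projectId: string
  clientEmail: string
  privateKey: string
}

function parseServiceAccountJson(raw: string): VertexServiceAccount | null {
  const trimmed = raw.trim()
  if (!trimmed) return null
  try {
    const parsed = JSON.parse(trimmed) as Record<string, unknown>
    const projectId = typeof parsed.project_id === 'string' ? parsed.project_id.trim() : ''
    const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : ''
    const privateKey = typeof parsed.private_key === 'string' ? parsed.private_key.trim() : ''
    // All three fields are required (vertexSaJsonMissingFields in the component).
    if (!projectId || !clientEmail || !privateKey) return null
    return { projectId, clientEmail, privateKey }
  } catch {
    // Unparseable input maps to vertexSaJsonInvalid in the component.
    return null
  }
}
```

The component additionally re-serializes the parsed object with `JSON.stringify` so that whatever is stored in `service_account_json` is normalized JSON rather than the raw upload.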
@@ -567,6 +567,221 @@
           </div>
         </div>
+
+        <!-- Vertex Service Account -->
+        <div v-if="(account.platform === 'gemini' || account.platform === 'anthropic') && account.type === 'service_account'" class="space-y-4">
+          <div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
+            <div>
+              <label class="input-label">Project ID</label>
+              <input
+                v-model="editVertexProjectId"
+                type="text"
+                class="input font-mono"
+                readonly
+                :placeholder="t('admin.accounts.vertexProjectIdPlaceholder')"
+              />
+              <p class="input-hint">{{ t('admin.accounts.vertexSaJsonEditHint') }}</p>
+            </div>
+            <div>
+              <label class="input-label">Location</label>
+              <select
+                v-model="editVertexLocation"
+                required
+                class="input font-mono"
+              >
+                <optgroup
+                  v-for="group in VERTEX_LOCATION_OPTIONS"
+                  :key="group.label"
+                  :label="group.label"
+                >
+                  <option
+                    v-for="option in group.options"
+                    :key="option.value"
+                    :value="option.value"
+                  >
+                    {{ option.label }}
+                  </option>
+                </optgroup>
+              </select>
+              <p class="input-hint">{{ t('admin.accounts.vertexLocationHint') }}</p>
+            </div>
+          </div>
+
+          <!-- Model Restriction Section for Service Account -->
+          <div class="border-t border-gray-200 pt-4 dark:border-dark-600">
+            <label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
+
+            <!-- Mode Toggle -->
+            <div class="mb-4 flex gap-2">
+              <button
+                type="button"
+                @click="modelRestrictionMode = 'whitelist'"
+                :class="[
+                  'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
+                  modelRestrictionMode === 'whitelist'
+                    ? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
+                    : 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
+                ]"
+              >
+                <svg
+                  class="mr-1.5 inline h-4 w-4"
+                  fill="none"
+                  viewBox="0 0 24 24"
+                  stroke="currentColor"
+                >
+                  <path
+                    stroke-linecap="round"
+                    stroke-linejoin="round"
+                    stroke-width="2"
+                    d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"
+                  />
+                </svg>
+                {{ t('admin.accounts.modelWhitelist') }}
+              </button>
+              <button
+                type="button"
+                @click="modelRestrictionMode = 'mapping'"
+                :class="[
+                  'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
+                  modelRestrictionMode === 'mapping'
+                    ? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
+                    : 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
+                ]"
+              >
+                <svg
+                  class="mr-1.5 inline h-4 w-4"
+                  fill="none"
+                  viewBox="0 0 24 24"
+                  stroke="currentColor"
+                >
+                  <path
+                    stroke-linecap="round"
+                    stroke-linejoin="round"
+                    stroke-width="2"
+                    d="M8 7h12m0 0l-4-4m4 4l-4 4m0 6H4m0 0l4 4m-4-4l4-4"
+                  />
+                </svg>
+                {{ t('admin.accounts.modelMapping') }}
+              </button>
+            </div>
+
+            <!-- Whitelist Mode -->
+            <div v-if="modelRestrictionMode === 'whitelist'">
+              <ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" />
+              <p class="text-xs text-gray-500 dark:text-gray-400">
+                {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
+                <span v-if="allowedModels.length === 0">{{
+                  t('admin.accounts.supportsAllModels')
+                }}</span>
+              </p>
+            </div>
+
+            <!-- Mapping Mode -->
+            <div v-else>
+              <div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
+                <p class="text-xs text-purple-700 dark:text-purple-400">
+                  <svg
+                    class="mr-1 inline h-4 w-4"
+                    fill="none"
+                    viewBox="0 0 24 24"
+                    stroke="currentColor"
+                  >
+                    <path
+                      stroke-linecap="round"
+                      stroke-linejoin="round"
+                      stroke-width="2"
+                      d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
+                    />
+                  </svg>
+                  {{ t('admin.accounts.mapRequestModels') }}
+                </p>
+              </div>
+
+              <!-- Model Mapping List -->
+              <div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
+                <div
+                  v-for="(mapping, index) in modelMappings"
+                  :key="getModelMappingKey(mapping)"
+                  class="flex items-center gap-2"
+                >
+                  <input
+                    v-model="mapping.from"
+                    type="text"
+                    class="input flex-1"
+                    :placeholder="t('admin.accounts.requestModel')"
+                  />
+                  <svg
+                    class="h-4 w-4 flex-shrink-0 text-gray-400"
+                    fill="none"
+                    viewBox="0 0 24 24"
+                    stroke="currentColor"
+                  >
+                    <path
+                      stroke-linecap="round"
+                      stroke-linejoin="round"
+                      stroke-width="2"
+                      d="M14 5l7 7m0 0l-7 7m7-7H3"
+                    />
+                  </svg>
+                  <input
+                    v-model="mapping.to"
+                    type="text"
+                    class="input flex-1"
+                    :placeholder="t('admin.accounts.actualModel')"
+                  />
+                  <button
+                    type="button"
+                    @click="removeModelMapping(index)"
+                    class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
+                  >
+                    <svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
+                      <path
+                        stroke-linecap="round"
+                        stroke-linejoin="round"
+                        stroke-width="2"
+                        d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
+                      />
+                    </svg>
+                  </button>
+                </div>
+              </div>
+
+              <button
+                type="button"
+                @click="addModelMapping"
+                class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
+              >
+                <svg
+                  class="mr-1 inline h-4 w-4"
+                  fill="none"
+                  viewBox="0 0 24 24"
+                  stroke="currentColor"
+                >
+                  <path
+                    stroke-linecap="round"
+                    stroke-linejoin="round"
+                    stroke-width="2"
+                    d="M12 4v16m8-8H4"
+                  />
+                </svg>
+                {{ t('admin.accounts.addMapping') }}
+              </button>
+
+              <!-- Quick Add Buttons -->
+              <div class="flex flex-wrap gap-2">
+                <button
+                  v-for="preset in presetMappings"
+                  :key="preset.label"
+                  type="button"
+                  @click="addPresetMapping(preset.from, preset.to)"
+                  :class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
+                >
+                  + {{ preset.label }}
+                </button>
+              </div>
+            </div>
+          </div>
+        </div>
 
         <!-- Bedrock fields (for bedrock type, both SigV4 and API Key modes) -->
         <div v-if="account.type === 'bedrock'" class="space-y-4">
           <!-- SigV4 fields -->
@@ -1919,6 +2134,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
 import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
 import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
 import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
+import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
 import {
   OPENAI_WS_MODE_CTX_POOL,
   OPENAI_WS_MODE_OFF,
@@ -1987,6 +2203,9 @@ const editBedrockSessionToken = ref('')
 const editBedrockRegion = ref('')
 const editBedrockForceGlobal = ref(false)
 const editBedrockApiKeyValue = ref('')
+const editVertexProjectId = ref('')
+const editVertexClientEmail = ref('')
+const editVertexLocation = ref('us-central1')
 const isBedrockAPIKeyMode = computed(() =>
   props.account?.type === 'bedrock' &&
   (props.account?.credentials as Record<string, unknown>)?.auth_mode === 'apikey'
@@ -2246,6 +2465,9 @@ const syncFormFromAccount = (newAccount: Account | null) => {
   const credentials = newAccount.credentials as Record<string, unknown> | undefined
   interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true
   autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true
+  editVertexProjectId.value = ''
+  editVertexClientEmail.value = ''
+  editVertexLocation.value = 'us-central1'
 
   // Load mixed scheduling setting (only for antigravity accounts)
   mixedScheduling.value = false
@@ -2467,6 +2689,31 @@ const syncFormFromAccount = (newAccount: Account | null) => {
   } else if (newAccount.type === 'upstream' && newAccount.credentials) {
     const credentials = newAccount.credentials as Record<string, unknown>
     editBaseUrl.value = (credentials.base_url as string) || ''
+  } else if ((newAccount.platform === 'gemini' || newAccount.platform === 'anthropic') && newAccount.type === 'service_account' && newAccount.credentials) {
+    const credentials = newAccount.credentials as Record<string, unknown>
+    editVertexProjectId.value = (credentials.project_id as string) || ''
+    editVertexClientEmail.value = (credentials.client_email as string) || ''
+    editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1'
+
+    // Load model mappings for service_account
+    const existingMappings = credentials.model_mapping as Record<string, string> | undefined
+    if (existingMappings && typeof existingMappings === 'object') {
+      const entries = Object.entries(existingMappings)
+      const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
+      if (isWhitelistMode) {
+        modelRestrictionMode.value = 'whitelist'
+        allowedModels.value = entries.map(([from]) => from)
+        modelMappings.value = []
+      } else {
+        modelRestrictionMode.value = 'mapping'
+        modelMappings.value = entries.map(([from, to]) => ({ from, to }))
+        allowedModels.value = []
+      }
+    } else {
+      modelRestrictionMode.value = 'whitelist'
+      modelMappings.value = []
+      allowedModels.value = []
+    }
   } else {
     const platformDefaultUrl =
       newAccount.platform === 'openai'
@@ -3057,6 +3304,46 @@ const handleSubmit = async () => {
       return
     }
+
+    updatePayload.credentials = newCredentials
+  } else if ((props.account.platform === 'gemini' || props.account.platform === 'anthropic') && props.account.type === 'service_account') {
+    const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
+    const newCredentials: Record<string, unknown> = { ...currentCredentials }
+
+    if (!editVertexProjectId.value.trim()) {
+      appStore.showError(t('admin.accounts.vertexSaJsonMissingProjectId'))
+      return
+    }
+    if (!editVertexClientEmail.value.trim()) {
+      appStore.showError(t('admin.accounts.vertexSaJsonMissingClientEmail'))
+      return
+    }
+    if (!editVertexLocation.value.trim()) {
+      appStore.showError(t('admin.accounts.vertexLocationRequired'))
+      return
+    }
+
+    if (!currentCredentials.service_account_json && !currentCredentials.service_account) {
+      appStore.showError(t('admin.accounts.vertexSaJsonRequired'))
+      return
+    }
+    newCredentials.project_id = editVertexProjectId.value.trim()
+    newCredentials.client_email = editVertexClientEmail.value.trim()
+    newCredentials.location = editVertexLocation.value.trim()
+    newCredentials.tier_id = 'vertex'
+
+    // Add model mapping if configured
+    const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
+    if (modelMapping) {
+      newCredentials.model_mapping = modelMapping
+    } else {
+      delete newCredentials.model_mapping
+    }
+
+    applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
+    if (!applyTempUnschedConfig(newCredentials)) {
+      return
+    }
 
     updatePayload.credentials = newCredentials
   } else if (props.account.type === 'bedrock') {
     const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}

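The `syncFormFromAccount` branch above infers the edit-form mode from the stored `model_mapping`: a non-empty mapping where every entry maps a model to itself is shown as a whitelist, anything else as explicit from→to mappings, and a missing mapping defaults to an empty whitelist (all models allowed). A standalone sketch of that inference (the function name and return shape are illustrative):

```typescript
// Mirrors the mode-detection logic in syncFormFromAccount for
// service_account credentials: all-identity model_mapping => whitelist,
// otherwise explicit from→to mappings.
type RestrictionState =
  | { mode: 'whitelist'; allowedModels: string[] }
  | { mode: 'mapping'; mappings: Array<{ from: string; to: string }> }

function detectRestrictionMode(modelMapping?: Record<string, string>): RestrictionState {
  const entries = Object.entries(modelMapping ?? {})
  const isWhitelist = entries.length > 0 && entries.every(([from, to]) => from === to)
  if (isWhitelist) {
    return { mode: 'whitelist', allowedModels: entries.map(([from]) => from) }
  }
  if (entries.length === 0) {
    // No mapping stored: default to an empty whitelist (all models allowed).
    return { mode: 'whitelist', allowedModels: [] }
  }
  return { mode: 'mapping', mappings: entries.map(([from, to]) => ({ from, to })) }
}
```

This is why the whitelist UI can round-trip through the same `model_mapping` field on submit: selecting models simply stores identity pairs.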
@@ -57,6 +57,19 @@ function makeAccount(overrides: Partial<Account>): Account {
 describe('AccountUsageCell', () => {
   beforeEach(() => {
     getUsage.mockReset()
+    Object.defineProperty(window, 'matchMedia', {
+      writable: true,
+      value: vi.fn().mockImplementation(() => ({
+        matches: true,
+        media: '(min-width: 768px)',
+        onchange: null,
+        addListener: vi.fn(),
+        removeListener: vi.fn(),
+        addEventListener: vi.fn(),
+        removeEventListener: vi.fn(),
+        dispatchEvent: vi.fn(),
+      }))
+    })
   })
 
   it('Antigravity 图片用量会聚合新旧 image 模型', async () => {
@@ -603,4 +616,43 @@ describe('AccountUsageCell', () => {
 
     expect(wrapper.text().trim()).toBe('-')
   })
+
+  it('Vertex 账号会在 Gemini 用量窗口里展示 today stats 徽章', async () => {
+    const wrapper = mount(AccountUsageCell, {
+      props: {
+        account: makeAccount({
+          id: 4001,
+          platform: 'gemini',
+          type: 'service_account',
+          credentials: {
+            tier_id: 'vertex',
+            project_id: 'vertex-proj',
+            client_email: 'svc@vertex-proj.iam.gserviceaccount.com',
+            location: 'global'
+          },
+          extra: {}
+        }),
+        todayStats: {
+          requests: 0,
+          tokens: 0,
+          cost: 0,
+          standard_cost: 0,
+          user_cost: 0
+        }
+      },
+      global: {
+        stubs: {
+          UsageProgressBar: true,
+          AccountQuotaInfo: true
+        }
+      }
+    })
+
+    await flushPromises()
+
+    expect(wrapper.text()).toContain('0 req')
+    expect(wrapper.text()).toContain('0')
+    expect(wrapper.text()).toContain('A $0.00')
+    expect(wrapper.text()).toContain('U $0.00')
+  })
 })

@@ -178,6 +178,45 @@ describe('BulkEditAccountModal', () => {
     expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false)
   })
+
+  it('OpenAI OAuth 批量编辑应提交 codex_cli_only 字段', async () => {
+    const wrapper = mountModal({
+      selectedPlatforms: ['openai'],
+      selectedTypes: ['oauth']
+    })
+
+    await wrapper.get('#bulk-edit-openai-codex-cli-only-enabled').setValue(true)
+    await wrapper.get('#bulk-edit-openai-codex-cli-only-toggle').trigger('click')
+    await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+    await flushPromises()
+
+    expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+    expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
+      extra: {
+        codex_cli_only: true
+      }
+    })
+  })
+
+  it('OpenAI API Key 批量编辑应提交 API Key 专属 WS mode 字段', async () => {
+    const wrapper = mountModal({
+      selectedPlatforms: ['openai'],
+      selectedTypes: ['apikey']
+    })
+
+    await wrapper.get('#bulk-edit-openai-apikey-ws-mode-enabled').setValue(true)
+    await wrapper.get('[data-testid="bulk-edit-openai-apikey-ws-mode-select"]').setValue('ctx_pool')
+    await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+    await flushPromises()
+
+    expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+    expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
+      extra: {
+        openai_apikey_responses_websockets_v2_mode: 'ctx_pool',
+        openai_apikey_responses_websockets_v2_enabled: true
+      }
+    })
+  })
 
   it('OpenAI 账号批量编辑可关闭自动透传', async () => {
     const wrapper = mountModal({
       selectedPlatforms: ['openai'],
@@ -217,4 +256,41 @@ describe('BulkEditAccountModal', () => {
     })
     expect(wrapper.text()).toContain('admin.accounts.openai.modelRestrictionDisabledByPassthrough')
   })
+
+  it('filtered-results 模式下应提交 filters 而不是 account_ids', async () => {
+    const wrapper = mountModal({
+      accountIds: [],
+      target: {
+        mode: 'filtered',
+        filters: {
+          platform: 'openai',
+          type: 'oauth',
+          status: 'active',
+          group: '12',
+          search: 'bulk-target',
+          privacy_mode: 'training_set_cf_blocked'
+        },
+        previewCount: 5,
+        selectedPlatforms: ['openai'],
+        selectedTypes: ['oauth']
+      }
+    })
+
+    await wrapper.get('#bulk-edit-status-enabled').setValue(true)
+    await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+    await flushPromises()
+
+    expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+    expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith({
+      filters: {
+        platform: 'openai',
+        type: 'oauth',
+        status: 'active',
+        group: '12',
+        search: 'bulk-target',
+        privacy_mode: 'training_set_cf_blocked'
+      },
+      status: 'active'
+    })
+  })
 })

@@ -1,9 +1,13 @@
 <template>
-  <div v-if="selectedIds.length > 0" class="mb-4 flex items-center justify-between p-3 bg-primary-50 rounded-lg dark:bg-primary-900/20">
+  <div class="mb-4 flex items-center justify-between rounded-lg bg-primary-50 p-3 dark:bg-primary-900/20">
     <div class="flex flex-wrap items-center gap-2">
-      <span class="text-sm font-medium text-primary-900 dark:text-primary-100">
+      <span v-if="selectedIds.length > 0" class="text-sm font-medium text-primary-900 dark:text-primary-100">
         {{ t('admin.accounts.bulkActions.selected', { count: selectedIds.length }) }}
       </span>
+      <span v-else class="text-sm font-medium text-primary-900 dark:text-primary-100">
+        {{ t('admin.accounts.bulkEdit.title') }}
+      </span>
+      <template v-if="selectedIds.length > 0">
       <button
         @click="$emit('select-page')"
         class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
@@ -17,19 +21,25 @@
       >
         {{ t('admin.accounts.bulkActions.clear') }}
       </button>
+      </template>
     </div>
     <div class="flex gap-2">
-      <button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
-      <button @click="$emit('reset-status')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.resetStatus') }}</button>
-      <button @click="$emit('refresh-token')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.refreshToken') }}</button>
-      <button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button>
-      <button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button>
-      <button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
+      <template v-if="selectedIds.length > 0">
+        <button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
+        <button @click="$emit('reset-status')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.resetStatus') }}</button>
+        <button @click="$emit('refresh-token')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.refreshToken') }}</button>
+        <button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button>
+        <button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button>
+        <button @click="$emit('edit-selected')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
+      </template>
+      <button @click="$emit('edit-filtered')" class="btn btn-primary btn-sm">
+        {{ t('admin.accounts.bulkEdit.submit') }}
+      </button>
     </div>
   </div>
 </template>
 
 <script setup lang="ts">
 import { useI18n } from 'vue-i18n'
-defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable', 'reset-status', 'refresh-token']); const { t } = useI18n()
+defineProps(['selectedIds']); defineEmits(['delete', 'edit-selected', 'edit-filtered', 'clear', 'select-page', 'toggle-schedulable', 'reset-status', 'refresh-token']); const { t } = useI18n()
 </script>

@@ -25,6 +25,7 @@
     <!-- Setup Token icon -->
     <Icon v-else-if="type === 'setup-token'" name="shield" size="xs" />
     <!-- API Key icon -->
+    <Icon v-else-if="type === 'service_account'" name="cloud" size="xs" />
     <Icon v-else name="key" size="xs" />
     <span>{{ typeLabel }}</span>
   </span>
@@ -88,6 +89,8 @@ const typeLabel = computed(() => {
       return 'Key'
     case 'bedrock':
       return 'AWS'
+    case 'service_account':
+      return 'Vertex'
     default:
       return props.type
   }

@@ -13,3 +13,51 @@ export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOT
 export const QUOTA_RESET_MODE_ROLLING = 'rolling' as const
 export const QUOTA_RESET_MODE_FIXED = 'fixed' as const
 export type QuotaResetMode = typeof QUOTA_RESET_MODE_ROLLING | typeof QUOTA_RESET_MODE_FIXED
+
+/** Vertex AI location options for Service Account accounts */
+export const VERTEX_LOCATION_OPTIONS = [
+  {
+    label: 'Common',
+    options: [
+      { value: 'us-central1', label: 'us-central1 (Iowa)' },
+      { value: 'global', label: 'global' },
+      { value: 'us', label: 'us' },
+      { value: 'eu', label: 'eu' }
+    ]
+  },
+  {
+    label: 'United States',
+    options: [
+      { value: 'us-east1', label: 'us-east1 (South Carolina)' },
+      { value: 'us-east4', label: 'us-east4 (Northern Virginia)' },
+      { value: 'us-east5', label: 'us-east5 (Columbus)' },
+      { value: 'us-south1', label: 'us-south1 (Dallas)' },
+      { value: 'us-west1', label: 'us-west1 (Oregon)' },
+      { value: 'us-west4', label: 'us-west4 (Las Vegas)' }
+    ]
+  },
+  {
+    label: 'Europe',
+    options: [
+      { value: 'europe-west1', label: 'europe-west1 (Belgium)' },
+      { value: 'europe-west2', label: 'europe-west2 (London)' },
+      { value: 'europe-west3', label: 'europe-west3 (Frankfurt)' },
+      { value: 'europe-west4', label: 'europe-west4 (Netherlands)' },
+      { value: 'europe-west6', label: 'europe-west6 (Zurich)' },
+      { value: 'europe-west8', label: 'europe-west8 (Milan)' },
+      { value: 'europe-west9', label: 'europe-west9 (Paris)' }
+    ]
+  },
+  {
+    label: 'Asia Pacific',
+    options: [
+      { value: 'asia-east1', label: 'asia-east1 (Taiwan)' },
+      { value: 'asia-east2', label: 'asia-east2 (Hong Kong)' },
+      { value: 'asia-northeast1', label: 'asia-northeast1 (Tokyo)' },
+      { value: 'asia-northeast3', label: 'asia-northeast3 (Seoul)' },
+      { value: 'asia-south1', label: 'asia-south1 (Mumbai)' },
+      { value: 'asia-southeast1', label: 'asia-southeast1 (Singapore)' },
+      { value: 'australia-southeast1', label: 'australia-southeast1 (Sydney)' }
+    ]
+  }
+] as const

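Since `VERTEX_LOCATION_OPTIONS` above is grouped for the `<optgroup>` UI, validating a submitted location server- or client-side means flattening the groups first. A small sketch under that assumption (the group shape matches the constant in the diff; the helper name and the trimmed `SAMPLE_GROUPS` data are illustrative):

```typescript
// Flattens grouped location options (shaped like VERTEX_LOCATION_OPTIONS in
// constants/account.ts) into a plain set of values for membership checks.
interface LocationGroup {
  label: string
  options: ReadonlyArray<{ value: string; label: string }>
}

function collectLocationValues(groups: ReadonlyArray<LocationGroup>): Set<string> {
  return new Set(groups.flatMap((group) => group.options.map((option) => option.value)))
}

// Illustrative subset of the real constant, for demonstration only.
const SAMPLE_GROUPS: LocationGroup[] = [
  {
    label: 'Common',
    options: [
      { value: 'us-central1', label: 'us-central1 (Iowa)' },
      { value: 'global', label: 'global' }
    ]
  },
  {
    label: 'Europe',
    options: [{ value: 'europe-west1', label: 'europe-west1 (Belgium)' }]
  }
]
```

Group labels exist only for display; deduplicating into a `Set` means a location is accepted regardless of which group it appears under.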
@@ -2815,6 +2815,26 @@ export default {
   claudeConsole: 'Claude Console',
   bedrockLabel: 'AWS Bedrock',
   bedrockDesc: 'SigV4 / API Key',
+  vertexLabel: 'Vertex',
+  vertexDesc: 'Service Account',
+  vertexAnthropicHint: 'Use a Google Cloud Service Account JSON to call Anthropic Claude via Vertex AI. It is recommended to configure model mapping to map client Claude model names to Vertex model IDs.',
+  vertexGeminiHint: 'Use a Google Cloud Service Account JSON to access Vertex AI Gemini. It is recommended to place Vertex accounts in a separate group to avoid mixing with AI Studio/Gemini OAuth on the same models.',
+  vertexSaJsonLabel: 'Service Account JSON',
+  vertexSaJsonLoaded: 'Service Account JSON loaded',
+  vertexSaJsonDrop: 'Drop Service Account JSON here',
+  vertexSaJsonKeyHidden: 'Key content is not displayed in the form.',
+  vertexSaJsonDropHint: 'Drag a .json file here, or click the button to select one.',
+  vertexSaJsonSelectBtn: 'Select JSON',
+  vertexSaJsonUploadHint: 'After uploading or dropping a JSON file, the project_id will be auto-extracted. Key content is only used for account creation.',
+  vertexSaJsonEditHint: 'Service Account JSON is not shown on the edit page; to change the JSON, delete the account and recreate it.',
+  vertexProjectIdPlaceholder: 'Auto-extracted from JSON',
+  vertexLocationHint: 'Available locations vary by Vertex model. Select the default endpoint location for this account.',
+  vertexLocationRequired: 'Please enter a Vertex location',
+  vertexSaJsonMissingFields: 'Service Account JSON is missing project_id, client_email, or private_key',
+  vertexSaJsonMissingProjectId: 'Service Account JSON is missing project_id',
+  vertexSaJsonMissingClientEmail: 'Service Account JSON is missing client_email',
+  vertexSaJsonInvalid: 'Service Account JSON format is invalid',
+  vertexSaJsonRequired: 'Please upload a Service Account JSON',
   oauthSetupToken: 'OAuth / Setup Token',
   addMethod: 'Add Method',
   setupTokenLongLived: 'Setup Token (Long-lived)',
@@ -4648,7 +4668,7 @@ export default {
   errorLogRetentionDays: 'Error Log Retention Days',
   minuteMetricsRetentionDays: 'Minute Metrics Retention Days',
   hourlyMetricsRetentionDays: 'Hourly Metrics Retention Days',
-  retentionDaysHint: 'Recommended 7-90 days, longer periods will consume more storage',
+  retentionDaysHint: 'Recommended 7-90 days; longer periods consume more storage. Set to 0 to wipe all history on every scheduled cleanup',
   aggregation: 'Pre-aggregation Tasks',
   enableAggregation: 'Enable Pre-aggregation',
   aggregationHint: 'Pre-aggregation improves query performance for long time windows',
@@ -4678,7 +4698,7 @@ export default {
   autoRefreshCountdown: 'Auto refresh: {seconds}s',
   validation: {
     title: 'Please fix the following issues',
-    retentionDaysRange: 'Retention days must be between 1-365 days',
+    retentionDaysRange: 'Retention days must be between 0 and 365 (0 = wipe all on every cleanup)',
     slaMinPercentRange: 'SLA minimum percentage must be between 0 and 100',
     ttftP99MaxRange: 'TTFT P99 maximum must be a number ≥ 0',
     requestErrorRateMaxRange: 'Request error rate maximum must be between 0 and 100',
@@ -5535,6 +5555,38 @@ export default {
   presetOpusOnlyDesc: 'Pass for Opus, filter others',
   commonPatterns: 'Common patterns'
 },
+openaiFastPolicy: {
+  title: 'OpenAI Fast/Flex Policy',
+  description: 'Intercept, filter, or pass OpenAI fast(priority) / flex requests based on the request body service_tier field. Applies to the OpenAI gateway only.',
+  empty: 'No rules configured. Click the button below to add one.',
+  ruleHeader: 'Rule #{index}',
+  removeRule: 'Remove rule',
+  addRule: 'Add rule',
+  saveHint: 'Saved together with system settings (click the global Save button at the bottom of the page).',
+  serviceTier: 'service_tier match',
+  tierAll: 'All tiers',
+  tierPriority: 'priority (fast)',
+  tierFlex: 'flex',
+  action: 'Action',
+  actionPass: 'Pass (keep service_tier)',
+  actionFilter: 'Filter (remove service_tier)',
+  actionBlock: 'Block (reject request)',
+  scope: 'Scope',
+  scopeAll: 'All accounts',
+  scopeOAuth: 'OAuth only',
+  scopeAPIKey: 'API Key only',
+  scopeBedrock: 'Bedrock only',
+  errorMessage: 'Error message',
+  errorMessagePlaceholder: 'Custom error message when blocked',
+  errorMessageHint: 'Leave empty for the default message.',
+  modelWhitelist: 'Model whitelist',
+  modelWhitelistHint: 'Leave empty to apply to all models. Supports exact match and wildcard prefix (e.g., gpt-5.5*).',
+  modelPatternPlaceholder: 'e.g., gpt-5.5 or gpt-5.5*',
+  addModelPattern: 'Add model pattern',
+  fallbackAction: 'Fallback action',
+  fallbackActionHint: 'Action for models not matching the whitelist.',
+  fallbackErrorMessagePlaceholder: 'Custom error message when non-whitelisted models are blocked'
+},
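The locale strings above describe the fast/flex policy semantics: match on `service_tier`, apply pass/filter/block, and fall back for models outside the whitelist, where whitelist entries are exact names or wildcard prefixes like `gpt-5.5*`. A minimal sketch of that evaluation logic — the type and function names here are illustrative assumptions, not the gateway's actual implementation:

```typescript
// Hypothetical rule shape inferred from the UI strings above.
type FastPolicyAction = 'pass' | 'filter' | 'block'

interface FastPolicyRule {
  service_tier: 'all' | 'priority' | 'flex'
  action: FastPolicyAction
  model_whitelist?: string[] // exact names or wildcard prefixes like 'gpt-5.5*'
  fallback_action?: FastPolicyAction
}

// Wildcard prefix match as described by modelWhitelistHint.
function matchesPattern(model: string, pattern: string): boolean {
  return pattern.endsWith('*')
    ? model.startsWith(pattern.slice(0, -1))
    : model === pattern
}

// Resolve the action a single rule takes for a request's tier and model.
function resolveAction(rule: FastPolicyRule, tier: string, model: string): FastPolicyAction {
  // A rule only applies when its tier matches (or it targets all tiers).
  if (rule.service_tier !== 'all' && rule.service_tier !== tier) return 'pass'
  const whitelist = rule.model_whitelist ?? []
  // Empty whitelist means the rule applies to every model.
  if (whitelist.length === 0) return rule.action
  const whitelisted = whitelist.some((p) => matchesPattern(model, p))
  // Non-whitelisted models get the fallback action (default: pass).
  return whitelisted ? rule.action : (rule.fallback_action ?? 'pass')
}
```

Usage follows the UI's wording: a rule matching `flex` with whitelist `['gpt-5.5*']` and fallback `filter` would block `gpt-5.5-mini` flex requests but only strip `service_tier` from other models.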
 wechatConnect: {
   title: 'WeChat Connect',
   description: 'Third-party login configuration for WeChat Open Platform or Official Account / Mini Program.',
@@ -2963,6 +2963,26 @@ export default {
   claudeConsole: 'Claude Console',
   bedrockLabel: 'AWS Bedrock',
   bedrockDesc: 'SigV4 / API Key',
+  vertexLabel: 'Vertex',
+  vertexDesc: 'Service Account',
+  vertexAnthropicHint: '使用 Google Cloud Service Account JSON 通过 Vertex AI 调用 Anthropic Claude。建议配置模型映射,将客户端 Claude 模型名映射到 Vertex 模型 ID。',
+  vertexGeminiHint: '使用 Google Cloud Service Account JSON 访问 Vertex AI Gemini。建议将 Vertex 账号放入独立分组,避免和 AI Studio/Gemini OAuth 同模型混调。',
+  vertexSaJsonLabel: 'Service Account JSON',
+  vertexSaJsonLoaded: '已读取 Service Account JSON',
+  vertexSaJsonDrop: '拖入 Service Account JSON',
+  vertexSaJsonKeyHidden: '密钥内容不会在表单中显示。',
+  vertexSaJsonDropHint: '把 .json 文件拖到这里,或点击按钮选择文件。',
+  vertexSaJsonSelectBtn: '选择 JSON',
+  vertexSaJsonUploadHint: '上传或拖入 JSON 后会自动读取 project_id,密钥内容仅用于创建账号提交。',
+  vertexSaJsonEditHint: 'Service Account JSON 不在编辑页显示;需要更换 JSON 时请删除账号后重新创建。',
+  vertexProjectIdPlaceholder: '从 JSON 自动读取',
+  vertexLocationHint: '不同 Vertex 模型可用 location 可能不同,这里选择账号默认 endpoint location。',
+  vertexLocationRequired: '请填写 Vertex location',
+  vertexSaJsonMissingFields: 'Service Account JSON 缺少 project_id、client_email 或 private_key',
+  vertexSaJsonMissingProjectId: 'Service Account JSON 缺少 project_id',
+  vertexSaJsonMissingClientEmail: 'Service Account JSON 缺少 client_email',
+  vertexSaJsonInvalid: 'Service Account JSON 格式无效',
+  vertexSaJsonRequired: '请上传 Service Account JSON',
   oauthSetupToken: 'OAuth / Setup Token',
   addMethod: '添加方式',
   setupTokenLongLived: 'Setup Token(长期有效)',
@@ -4810,7 +4830,7 @@ export default {
   errorLogRetentionDays: '错误日志保留天数',
   minuteMetricsRetentionDays: '分钟指标保留天数',
   hourlyMetricsRetentionDays: '小时指标保留天数',
-  retentionDaysHint: '建议保留7-90天,过长会占用存储空间',
+  retentionDaysHint: '建议保留 7-90 天,过长会占用存储空间;填 0 表示每次定时清理时清空所有历史',
   aggregation: '预聚合任务',
   enableAggregation: '启用预聚合任务',
   aggregationHint: '预聚合可提升长时间窗口查询性能',
@@ -4841,7 +4861,7 @@ export default {
   autoRefreshCountdown: '自动刷新:{seconds}s',
   validation: {
     title: '请先修正以下问题',
-    retentionDaysRange: '保留天数必须在1-365天之间',
+    retentionDaysRange: '保留天数必须在 0-365 天之间(0 = 每次清理时清空所有)',
     slaMinPercentRange: 'SLA最低百分比必须在0-100之间',
     ttftP99MaxRange: 'TTFT P99最大值必须大于等于0',
     requestErrorRateMaxRange: '请求错误率最大值必须在0-100之间',
@@ -5695,6 +5715,38 @@ export default {
   presetOpusOnlyDesc: 'Opus 透传,其他模型过滤',
   commonPatterns: '常用模式'
 },
+openaiFastPolicy: {
+  title: 'OpenAI Fast/Flex 策略',
+  description: '基于请求体 service_tier 字段拦截/过滤/透传 OpenAI fast(priority) 与 flex 请求;仅作用于 OpenAI 网关。',
+  empty: '尚未配置任何规则。点击下方按钮新增。',
+  ruleHeader: '规则 #{index}',
+  removeRule: '删除规则',
+  addRule: '新增规则',
+  saveHint: '保存时随系统设置一起提交(点击页面底部「保存」按钮)。',
+  serviceTier: 'service_tier 匹配',
+  tierAll: '全部 tier',
+  tierPriority: 'priority(fast)',
+  tierFlex: 'flex',
+  action: '处理方式',
+  actionPass: '透传(保留 service_tier)',
+  actionFilter: '过滤(移除 service_tier)',
+  actionBlock: '拦截(拒绝请求)',
+  scope: '生效范围',
+  scopeAll: '全部账号',
+  scopeOAuth: '仅 OAuth 账号',
+  scopeAPIKey: '仅 API Key 账号',
+  scopeBedrock: '仅 Bedrock 账号',
+  errorMessage: '错误消息',
+  errorMessagePlaceholder: '拦截时返回的自定义错误消息',
+  errorMessageHint: '留空则使用默认错误消息。',
+  modelWhitelist: '模型白名单',
+  modelWhitelistHint: '留空表示对所有模型生效;支持精确匹配与通配符(如 gpt-5.5*)。',
+  modelPatternPlaceholder: '例如: gpt-5.5 或 gpt-5.5*',
+  addModelPattern: '添加模型规则',
+  fallbackAction: '未匹配模型处理方式',
+  fallbackActionHint: '当请求模型不在白名单中时的处理方式。',
+  fallbackErrorMessagePlaceholder: '未匹配模型被拦截时返回的自定义错误消息'
+},
 wechatConnect: {
   title: '微信登录',
   description: '用于微信开放平台或公众号/小程序的第三方登录配置。',
@@ -643,7 +643,7 @@ export interface UpdateGroupRequest {
 // ==================== Account & Proxy Types ====================
 
 export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity'
-export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock'
+export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'service_account'
 export type OAuthAddMethod = 'oauth' | 'setup-token'
 export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'
@@ -1,93 +1,18 @@
 /**
- * Usage request scheduler — throttles Anthropic API calls by proxy exit.
+ * Usage request scheduler.
  *
- * Anthropic OAuth/setup-token accounts sharing the same proxy exit are placed
- * into a serial queue with a random 1–2s delay between requests, preventing
- * upstream 429 rate-limit errors.
- *
- * Proxy identity = host:port:username — two proxy records pointing to the
- * same exit share a single queue. Accounts without a proxy go into a
- * "direct" queue.
- *
- * All other platforms bypass the queue and execute immediately.
+ * All platforms execute immediately without queuing — the backend uses
+ * passive sampling so upstream 429 rate-limit errors are no longer a concern.
  */
 
 import type { Account } from '@/types'
 
-const GROUP_DELAY_MIN_MS = 1000
-const GROUP_DELAY_MAX_MS = 2000
-
-type Task<T> = {
-  fn: () => Promise<T>
-  resolve: (value: T) => void
-  reject: (reason: unknown) => void
-}
-
-const queues = new Map<string, Task<unknown>[]>()
-const running = new Set<string>()
-
-/** Whether this account needs throttled queuing. */
-function needsThrottle(account: Account): boolean {
-  return (
-    account.platform === 'anthropic' &&
-    (account.type === 'oauth' || account.type === 'setup-token')
-  )
-}
-
-/** Build a queue key from proxy connection details. */
-function buildGroupKey(account: Account): string {
-  const proxy = account.proxy
-  const proxyIdentity = proxy
-    ? `${proxy.host}:${proxy.port}:${proxy.username || ''}`
-    : 'direct'
-  return `anthropic:${proxyIdentity}`
-}
-
-async function drain(groupKey: string) {
-  if (running.has(groupKey)) return
-  running.add(groupKey)
-
-  const queue = queues.get(groupKey)
-  while (queue && queue.length > 0) {
-    const task = queue.shift()!
-    try {
-      const result = await task.fn()
-      task.resolve(result)
-    } catch (err) {
-      task.reject(err)
-    }
-    if (queue.length > 0) {
-      const jitter = GROUP_DELAY_MIN_MS + Math.random() * (GROUP_DELAY_MAX_MS - GROUP_DELAY_MIN_MS)
-      await new Promise((r) => setTimeout(r, jitter))
-    }
-  }
-
-  running.delete(groupKey)
-  queues.delete(groupKey)
-}
-
 /**
- * Schedule a usage fetch. Anthropic accounts are queued by proxy exit;
- * all other platforms execute immediately.
+ * Schedule a usage fetch. All requests execute immediately.
  */
 export function enqueueUsageRequest<T>(
-  account: Account,
+  _account: Account,
   fn: () => Promise<T>
 ): Promise<T> {
-  // Non-Anthropic → fire immediately, no queuing
-  if (!needsThrottle(account)) {
-    return fn()
-  }
-
-  const key = buildGroupKey(account)
-
-  return new Promise<T>((resolve, reject) => {
-    let queue = queues.get(key)
-    if (!queue) {
-      queue = []
-      queues.set(key, queue)
-    }
-    queue.push({ fn, resolve, reject } as Task<unknown>)
-    drain(key)
-  })
+  return fn()
 }
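After this change `enqueueUsageRequest` is a plain pass-through, so existing call sites keep working unchanged. A self-contained sketch of the simplified contract — the `Account` interface here is trimmed to a stand-in for the real type from `@/types`, and `demo` is a hypothetical caller:

```typescript
// Trimmed stand-in for the real Account type from '@/types'.
interface Account {
  platform: string
  type: string
}

// Post-change behavior: every platform resolves immediately — no per-proxy
// queue, no 1–2s jitter between requests.
function enqueueUsageRequest<T>(_account: Account, fn: () => Promise<T>): Promise<T> {
  return fn()
}

// Hypothetical caller: the signature is unchanged, so call sites need no edits.
async function demo(): Promise<string> {
  const account: Account = { platform: 'anthropic', type: 'oauth' }
  return enqueueUsageRequest(account, async () => 'usage-payload')
}
```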
@@ -141,7 +141,17 @@
 </div>
 </template>
 <template #table>
-<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @reset-status="handleBulkResetStatus" @refresh-token="handleBulkRefreshToken" @edit="showBulkEdit = true" @clear="clearSelection" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" />
+<AccountBulkActionsBar
+  :selected-ids="selIds"
+  @delete="handleBulkDelete"
+  @reset-status="handleBulkResetStatus"
+  @refresh-token="handleBulkRefreshToken"
+  @edit-selected="openBulkEditSelected"
+  @edit-filtered="openBulkEditFiltered"
+  @clear="clearSelection"
+  @select-page="selectPage"
+  @toggle-schedulable="handleBulkToggleSchedulable"
+/>
 <div ref="accountTableRef" class="flex min-h-0 flex-1 flex-col overflow-hidden">
 <DataTable
 ref="dataTableRef"
@@ -303,7 +313,17 @@
 <AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @schedule="handleSchedule" @reauth="handleReAuth" @refresh-token="handleRefresh" @recover-state="handleRecoverState" @reset-quota="handleResetQuota" @set-privacy="handleSetPrivacy" />
 <SyncFromCrsModal :show="showSync" @close="showSync = false" @synced="reload" />
 <ImportDataModal :show="showImportData" @close="showImportData = false" @imported="handleDataImported" />
-<BulkEditAccountModal :show="showBulkEdit" :account-ids="selIds" :selected-platforms="selPlatforms" :selected-types="selTypes" :proxies="proxies" :groups="groups" @close="showBulkEdit = false" @updated="handleBulkUpdated" />
+<BulkEditAccountModal
+  :show="showBulkEdit"
+  :account-ids="selIds"
+  :selected-platforms="selPlatforms"
+  :selected-types="selTypes"
+  :target="bulkEditTarget ?? undefined"
+  :proxies="proxies"
+  :groups="groups"
+  @close="showBulkEdit = false"
+  @updated="handleBulkUpdated"
+/>
 <TempUnschedStatusModal :show="showTempUnsched" :account="tempUnschedAcc" @close="showTempUnsched = false" @reset="handleTempUnschedReset" />
 <ConfirmDialog :show="showDeleteDialog" :title="t('admin.accounts.deleteAccount')" :message="t('admin.accounts.deleteConfirm', { name: deletingAcc?.name })" :confirm-text="t('common.delete')" :cancel-text="t('common.cancel')" :danger="true" @confirm="confirmDelete" @cancel="showDeleteDialog = false" />
 <ConfirmDialog :show="showExportDataDialog" :title="t('admin.accounts.dataExport')" :message="t('admin.accounts.dataExportConfirmMessage')" :confirm-text="t('admin.accounts.dataExportConfirm')" :cancel-text="t('common.cancel')" @confirm="handleExportData" @cancel="showExportDataDialog = false">
@@ -364,6 +384,29 @@ const proxies = ref<AccountProxy[]>([])
 const groups = ref<AdminGroup[]>([])
 const accountTableRef = ref<HTMLElement | null>(null)
 const dataTableRef = ref<InstanceType<typeof DataTable> | null>(null)
+type AccountBulkEditTarget =
+  | {
+      mode: 'selected'
+      accountIds: number[]
+      selectedPlatforms: AccountPlatform[]
+      selectedTypes: AccountType[]
+    }
+  | {
+      mode: 'filtered'
+      filters: {
+        platform?: string
+        type?: string
+        status?: string
+        group?: string
+        search?: string
+        privacy_mode?: string
+        sort_by?: string
+        sort_order?: AccountSortOrder
+      }
+      previewCount: number
+      selectedPlatforms: AccountPlatform[]
+      selectedTypes: AccountType[]
+    }
 const selPlatforms = computed<AccountPlatform[]>(() => {
   const platforms = new Set(
     accounts.value
@@ -387,6 +430,7 @@ const showImportData = ref(false)
 const showExportDataDialog = ref(false)
 const includeProxyOnExport = ref(true)
 const showBulkEdit = ref(false)
+const bulkEditTarget = ref<AccountBulkEditTarget | null>(null)
 const showTempUnsched = ref(false)
 const showDeleteDialog = ref(false)
 const showReAuth = ref(false)
@@ -1216,7 +1260,57 @@ const handleBulkToggleSchedulable = async (schedulable: boolean) => {
     appStore.showError(t('common.error'))
   }
 }
-const handleBulkUpdated = () => { showBulkEdit.value = false; clearSelection(); reload() }
+const buildBulkEditFilterSnapshot = () => {
+  const rawParams = toRaw(params) as Record<string, unknown>
+  const sortOrder: AccountSortOrder = rawParams.sort_order === 'desc' ? 'desc' : 'asc'
+  return {
+    platform: typeof rawParams.platform === 'string' ? rawParams.platform : '',
+    type: typeof rawParams.type === 'string' ? rawParams.type : '',
+    status: typeof rawParams.status === 'string' ? rawParams.status : '',
+    group: typeof rawParams.group === 'string' ? rawParams.group : '',
+    search: typeof rawParams.search === 'string' ? rawParams.search : '',
+    privacy_mode: typeof rawParams.privacy_mode === 'string' ? rawParams.privacy_mode : '',
+    sort_by: typeof rawParams.sort_by === 'string' ? rawParams.sort_by : '',
+    sort_order: sortOrder
+  }
+}
+
+const collectSelectionMetadata = (rows: Account[]) => {
+  const selectedPlatforms = Array.from(new Set(rows.map(account => account.platform)))
+  const selectedTypes = Array.from(new Set(rows.map(account => account.type)))
+  return { selectedPlatforms, selectedTypes }
+}
+
+const openBulkEditSelected = () => {
+  bulkEditTarget.value = {
+    mode: 'selected',
+    accountIds: [...selIds.value],
+    selectedPlatforms: [...selPlatforms.value],
+    selectedTypes: [...selTypes.value]
+  }
+  showBulkEdit.value = true
+}
+
+const openBulkEditFiltered = async () => {
+  const filters = buildBulkEditFilterSnapshot()
+  const preview = await adminAPI.accounts.list(1, 100, filters)
+  const { selectedPlatforms, selectedTypes } = collectSelectionMetadata(preview.items)
+  bulkEditTarget.value = {
+    mode: 'filtered',
+    filters,
+    previewCount: preview.total,
+    selectedPlatforms,
+    selectedTypes
+  }
+  showBulkEdit.value = true
+}
+
+const handleBulkUpdated = () => {
+  showBulkEdit.value = false
+  bulkEditTarget.value = null
+  clearSelection()
+  reload()
+}
 const handleDataImported = () => { showImportData.value = false; reload() }
 const ACCOUNT_UNGROUPED_GROUP_QUERY_VALUE = 'ungrouped'
 const ACCOUNT_PRIVACY_MODE_UNSET_QUERY_VALUE = '__unset__'
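The `AccountBulkEditTarget` union built by `openBulkEditSelected` and `openBulkEditFiltered` above is a discriminated union, so consumers can narrow on `mode` to access variant-specific fields type-safely. A minimal sketch with the field lists trimmed to what the example uses — `affectedCount` is illustrative, not the modal's actual code:

```typescript
// Trimmed mirror of the union declared in the page's script block.
type AccountBulkEditTarget =
  | { mode: 'selected'; accountIds: number[] }
  | { mode: 'filtered'; filters: { platform?: string }; previewCount: number }

// Narrowing on the `mode` discriminant gives type-safe access per variant.
function affectedCount(target: AccountBulkEditTarget): number {
  if (target.mode === 'selected') {
    // Only the 'selected' variant has accountIds.
    return target.accountIds.length
  }
  // Only the 'filtered' variant has previewCount (from the preview query).
  return target.previewCount
}
```

This is why the filtered path fetches a preview first: the modal can show how many accounts a filter-wide edit will touch without materializing every ID.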
@@ -949,6 +949,285 @@
 </template>
 </div>
 </div>
+<!-- OpenAI Fast/Flex Policy Settings -->
+<div class="card">
+  <div
+    class="border-b border-gray-100 px-6 py-4 dark:border-dark-700"
+  >
+    <h2 class="text-lg font-semibold text-gray-900 dark:text-white">
+      {{ t("admin.settings.openaiFastPolicy.title") }}
+    </h2>
+    <p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
+      {{ t("admin.settings.openaiFastPolicy.description") }}
+    </p>
+  </div>
+  <div class="space-y-5 p-6">
+    <!-- Empty state -->
+    <div
+      v-if="openaiFastPolicyForm.rules.length === 0"
+      class="rounded-lg border border-dashed border-gray-200 p-6 text-center text-sm text-gray-500 dark:border-dark-600 dark:text-gray-400"
+    >
+      {{ t("admin.settings.openaiFastPolicy.empty") }}
+    </div>
+
+    <!-- Rule Cards -->
+    <div
+      v-for="(rule, ruleIndex) in openaiFastPolicyForm.rules"
+      :key="ruleIndex"
+      class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
+    >
+      <div class="mb-3 flex items-center justify-between">
+        <span
+          class="text-sm font-medium text-gray-900 dark:text-white"
+        >
+          {{
+            t("admin.settings.openaiFastPolicy.ruleHeader", {
+              index: ruleIndex + 1,
+            })
+          }}
+        </span>
+        <button
+          type="button"
+          @click="removeOpenAIFastPolicyRule(ruleIndex)"
+          class="rounded p-1 text-red-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
+          :title="t('admin.settings.openaiFastPolicy.removeRule')"
+        >
+          <svg
+            class="h-4 w-4"
+            fill="none"
+            viewBox="0 0 24 24"
+            stroke="currentColor"
+            stroke-width="2"
+          >
+            <path
+              stroke-linecap="round"
+              stroke-linejoin="round"
+              d="M6 18L18 6M6 6l12 12"
+            />
+          </svg>
+        </button>
+      </div>
+
+      <div class="grid grid-cols-1 gap-4 md:grid-cols-3">
+        <!-- Service Tier -->
+        <div>
+          <label
+            class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
+          >
+            {{ t("admin.settings.openaiFastPolicy.serviceTier") }}
+          </label>
+          <Select
+            :modelValue="rule.service_tier"
+            @update:modelValue="
+              rule.service_tier = $event as 'all' | 'priority' | 'flex'
+            "
+            :options="openaiFastPolicyTierOptions"
+          />
+        </div>
+
+        <!-- Action -->
+        <div>
+          <label
+            class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
+          >
+            {{ t("admin.settings.openaiFastPolicy.action") }}
+          </label>
+          <Select
+            :modelValue="rule.action"
+            @update:modelValue="
+              rule.action = $event as 'pass' | 'filter' | 'block'
+            "
+            :options="openaiFastPolicyActionOptions"
+          />
+        </div>
+
+        <!-- Scope -->
+        <div>
+          <label
+            class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
+          >
+            {{ t("admin.settings.openaiFastPolicy.scope") }}
+          </label>
+          <Select
+            :modelValue="rule.scope"
+            @update:modelValue="
+              rule.scope = $event as 'all' | 'oauth' | 'apikey' | 'bedrock'
+            "
+            :options="openaiFastPolicyScopeOptions"
+          />
+        </div>
+      </div>
+
+      <!-- Error Message (only when action=block) -->
+      <div v-if="rule.action === 'block'" class="mt-3">
+        <label
+          class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
+        >
+          {{ t("admin.settings.openaiFastPolicy.errorMessage") }}
+        </label>
+        <input
+          v-model="rule.error_message"
+          type="text"
+          class="input"
+          :placeholder="
+            t('admin.settings.openaiFastPolicy.errorMessagePlaceholder')
+          "
+        />
+        <p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
+          {{ t("admin.settings.openaiFastPolicy.errorMessageHint") }}
+        </p>
+      </div>
+
+      <!-- Model Whitelist -->
+      <div class="mt-3">
+        <label
+          class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
+        >
+          {{ t("admin.settings.openaiFastPolicy.modelWhitelist") }}
+        </label>
+        <p class="mb-2 text-xs text-gray-400 dark:text-gray-500">
+          {{ t("admin.settings.openaiFastPolicy.modelWhitelistHint") }}
+        </p>
+        <div
+          v-for="(_, patternIdx) in rule.model_whitelist || []"
+          :key="patternIdx"
+          class="mb-1.5 flex items-center gap-2"
+        >
+          <input
+            v-model="rule.model_whitelist![patternIdx]"
+            type="text"
+            class="input input-sm flex-1"
+            :placeholder="
+              t('admin.settings.openaiFastPolicy.modelPatternPlaceholder')
+            "
+          />
+          <button
+            type="button"
+            @click="removeOpenAIFastPolicyModelPattern(rule, patternIdx)"
+            class="shrink-0 rounded p-1 text-red-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
+          >
+            <svg
+              class="h-4 w-4"
+              fill="none"
+              viewBox="0 0 24 24"
+              stroke="currentColor"
+              stroke-width="2"
+            >
+              <path
+                stroke-linecap="round"
+                stroke-linejoin="round"
+                d="M6 18L18 6M6 6l12 12"
+              />
+            </svg>
+          </button>
+        </div>
+        <button
+          type="button"
+          @click="addOpenAIFastPolicyModelPattern(rule)"
+          class="mb-2 inline-flex items-center gap-1 text-xs text-primary-600 transition-colors hover:text-primary-700 dark:text-primary-400 dark:hover:text-primary-300"
+        >
+          <svg
+            class="h-3.5 w-3.5"
+            fill="none"
+            viewBox="0 0 24 24"
+            stroke="currentColor"
+            stroke-width="2"
+          >
+            <path
+              stroke-linecap="round"
+              stroke-linejoin="round"
+              d="M12 4v16m8-8H4"
+            />
+          </svg>
+          {{ t("admin.settings.openaiFastPolicy.addModelPattern") }}
+        </button>
+      </div>
+
+      <!-- Fallback Action (only when model_whitelist is non-empty) -->
+      <div
+        v-if="rule.model_whitelist && rule.model_whitelist.length > 0"
+        class="mt-3"
+      >
+        <label
+          class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
+        >
+          {{ t("admin.settings.openaiFastPolicy.fallbackAction") }}
+        </label>
+        <Select
+          :modelValue="rule.fallback_action || 'pass'"
+          @update:modelValue="
+            rule.fallback_action = $event as 'pass' | 'filter' | 'block'
+          "
+          :options="openaiFastPolicyActionOptions"
+        />
+        <p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
+          {{ t("admin.settings.openaiFastPolicy.fallbackActionHint") }}
+        </p>
+        <div v-if="rule.fallback_action === 'block'" class="mt-2">
+          <input
+            v-model="rule.fallback_error_message"
+            type="text"
+            class="input"
+            :placeholder="
+              t('admin.settings.openaiFastPolicy.fallbackErrorMessagePlaceholder')
+            "
+          />
+        </div>
+      </div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Add Rule Button -->
|
||||||
|
<div>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="addOpenAIFastPolicyRule"
|
||||||
|
class="btn btn-secondary btn-sm inline-flex items-center gap-1"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
class="h-4 w-4"
|
||||||
|
fill="none"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="2"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
stroke-linecap="round"
|
||||||
|
stroke-linejoin="round"
|
||||||
|
d="M12 4v16m8-8H4"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
{{ t("admin.settings.openaiFastPolicy.addRule") }}
|
||||||
|
</button>
|
||||||
|
<p class="mt-2 text-xs text-gray-400 dark:text-gray-500">
|
||||||
|
{{ t("admin.settings.openaiFastPolicy.saveHint") }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<!-- /Tab: Gateway -->
|
<!-- /Tab: Gateway -->
|
||||||
|
|
||||||
@@ -5199,6 +5478,7 @@ import type {
   SystemSettings,
   UpdateSettingsRequest,
   DefaultSubscriptionSetting,
+  OpenAIFastPolicyRule,
   WeChatConnectMode,
   WebSearchEmulationConfig,
   WebSearchProviderConfig,
@@ -5337,6 +5617,14 @@ const betaPolicyForm = reactive({
   }>,
 });

+// OpenAI Fast/Flex policy state
+const openaiFastPolicyForm = reactive({
+  rules: [] as OpenAIFastPolicyRule[],
+});
+// Tracks whether openai_fast_policy_settings was successfully loaded from the
+// backend, so that saving does not overwrite the default rules with an empty
+// array when the backend GET fails or the field is missing.
+const openaiFastPolicyLoaded = ref(false);
+
 const tablePageSizeMin = 5;
 const tablePageSizeMax = 1000;
 const tablePageSizeDefault = 20;
@@ -6116,6 +6404,23 @@ async function loadSettings() {
     );
     form.oidc_connect_client_secret = "";
+
+    // Load OpenAI fast/flex policy rules from bulk settings.
+    // Only populate the form and mark it as loaded when the payload actually
+    // contains the field; otherwise leave the form empty so saveSettings
+    // skips the field and does not overwrite the backend's default rules.
+    if (
+      settings.openai_fast_policy_settings &&
+      Array.isArray(settings.openai_fast_policy_settings.rules)
+    ) {
+      openaiFastPolicyForm.rules =
+        settings.openai_fast_policy_settings.rules.map((rule) => ({
+          ...rule,
+          model_whitelist: rule.model_whitelist
+            ? [...rule.model_whitelist]
+            : [],
+        }));
+      openaiFastPolicyLoaded.value = true;
+    }
+
     // Load web search emulation config separately
     await loadWebSearchConfig();
   } catch (error: unknown) {
@@ -6460,10 +6765,39 @@ async function saveSettings() {
       affiliate_enabled: form.affiliate_enabled,
     };
+
+    // Only write openai_fast_policy_settings back when it was successfully
+    // loaded from the backend; otherwise omit the whole field so the backend
+    // keeps its existing rules (including defaults).
+    if (openaiFastPolicyLoaded.value) {
+      payload.openai_fast_policy_settings = {
+        rules: openaiFastPolicyForm.rules.map((rule) => {
+          const whitelist = (rule.model_whitelist || [])
+            .map((p) => p.trim())
+            .filter((p) => p !== "");
+          const hasWhitelist = whitelist.length > 0;
+          return {
+            service_tier: rule.service_tier,
+            action: rule.action,
+            scope: rule.scope,
+            error_message:
+              rule.action === "block" ? rule.error_message : undefined,
+            model_whitelist: hasWhitelist ? whitelist : undefined,
+            fallback_action: hasWhitelist
+              ? rule.fallback_action || "pass"
+              : undefined,
+            fallback_error_message:
+              hasWhitelist && rule.fallback_action === "block"
+                ? rule.fallback_error_message
+                : undefined,
+          };
+        }),
+      };
+    }
+
     appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults);

     const updated = await adminAPI.settings.updateSettings(payload);
     for (const [key, value] of Object.entries(updated)) {
+      if (key === "openai_fast_policy_settings") continue;
       if (value !== null && value !== undefined) {
         (form as Record<string, unknown>)[key] = value;
       }
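The rule normalization in the saveSettings hunk above can be exercised in isolation. The sketch below lifts it into a standalone function; the `Rule` type is a simplified stand-in for `OpenAIFastPolicyRule`, not the project's actual type definition:

```typescript
// Simplified stand-in for OpenAIFastPolicyRule (illustrative only).
type Rule = {
  service_tier: string
  action: 'pass' | 'filter' | 'block'
  scope: string
  error_message?: string
  model_whitelist?: string[]
  fallback_action?: 'pass' | 'filter' | 'block'
  fallback_error_message?: string
}

// Mirrors the payload mapping above: blank whitelist entries are trimmed
// away, and the fallback fields are emitted only when a non-empty
// whitelist remains; error_message is kept only for "block" rules.
function normalizeRule(rule: Rule): Rule {
  const whitelist = (rule.model_whitelist || [])
    .map((p) => p.trim())
    .filter((p) => p !== '')
  const hasWhitelist = whitelist.length > 0
  return {
    service_tier: rule.service_tier,
    action: rule.action,
    scope: rule.scope,
    error_message: rule.action === 'block' ? rule.error_message : undefined,
    model_whitelist: hasWhitelist ? whitelist : undefined,
    fallback_action: hasWhitelist ? rule.fallback_action || 'pass' : undefined,
    fallback_error_message:
      hasWhitelist && rule.fallback_action === 'block'
        ? rule.fallback_error_message
        : undefined
  }
}
```

A rule whose whitelist collapses to empty after trimming is sent without `model_whitelist` or any fallback fields, which is what lets the backend treat the whitelist as absent rather than as an empty list.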
@@ -6507,6 +6841,20 @@ async function saveSettings() {
       form.wechat_connect_mode,
     );
     form.oidc_connect_client_secret = "";
+    // Refresh OpenAI fast/flex policy from server response
+    if (
+      updated.openai_fast_policy_settings &&
+      Array.isArray(updated.openai_fast_policy_settings.rules)
+    ) {
+      openaiFastPolicyForm.rules =
+        updated.openai_fast_policy_settings.rules.map((rule) => ({
+          ...rule,
+          model_whitelist: rule.model_whitelist
+            ? [...rule.model_whitelist]
+            : [],
+        }));
+      openaiFastPolicyLoaded.value = true;
+    }
     // Save web search emulation config separately (errors handled internally)
     const wsOk = await saveWebSearchConfig();
     // Refresh cached settings so sidebar/header update immediately
@@ -6846,6 +7194,61 @@ async function loadBetaPolicySettings() {
     }
   }
 }
+
+// ==================== OpenAI Fast/Flex Policy ====================
+
+const openaiFastPolicyTierOptions = computed(() => [
+  { value: "all", label: t("admin.settings.openaiFastPolicy.tierAll") },
+  {
+    value: "priority",
+    label: t("admin.settings.openaiFastPolicy.tierPriority"),
+  },
+  { value: "flex", label: t("admin.settings.openaiFastPolicy.tierFlex") },
+]);
+
+const openaiFastPolicyActionOptions = computed(() => [
+  { value: "pass", label: t("admin.settings.openaiFastPolicy.actionPass") },
+  { value: "filter", label: t("admin.settings.openaiFastPolicy.actionFilter") },
+  { value: "block", label: t("admin.settings.openaiFastPolicy.actionBlock") },
+]);
+
+const openaiFastPolicyScopeOptions = computed(() => [
+  { value: "all", label: t("admin.settings.openaiFastPolicy.scopeAll") },
+  { value: "oauth", label: t("admin.settings.openaiFastPolicy.scopeOAuth") },
+  { value: "apikey", label: t("admin.settings.openaiFastPolicy.scopeAPIKey") },
+  {
+    value: "bedrock",
+    label: t("admin.settings.openaiFastPolicy.scopeBedrock"),
+  },
+]);
+
+function addOpenAIFastPolicyRule() {
+  openaiFastPolicyForm.rules.push({
+    service_tier: "priority",
+    action: "filter",
+    scope: "all",
+    error_message: "",
+    model_whitelist: [],
+    fallback_action: "pass",
+    fallback_error_message: "",
+  });
+}
+
+function removeOpenAIFastPolicyRule(index: number) {
+  openaiFastPolicyForm.rules.splice(index, 1);
+}
+
+function addOpenAIFastPolicyModelPattern(rule: OpenAIFastPolicyRule) {
+  if (!rule.model_whitelist) rule.model_whitelist = [];
+  rule.model_whitelist.push("");
+}
+
+function removeOpenAIFastPolicyModelPattern(
+  rule: OpenAIFastPolicyRule,
+  idx: number,
+) {
+  rule.model_whitelist?.splice(idx, 1);
+}
+
 async function saveBetaPolicySettings() {
   betaPolicySaving.value = true;
   try {
frontend/src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts (new file, 152 lines)
@@ -0,0 +1,152 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import AccountsView from '../AccountsView.vue'
+
+const {
+  listAccounts,
+  listWithEtag,
+  getBatchTodayStats,
+  getAllProxies,
+  getAllGroups
+} = vi.hoisted(() => ({
+  listAccounts: vi.fn(),
+  listWithEtag: vi.fn(),
+  getBatchTodayStats: vi.fn(),
+  getAllProxies: vi.fn(),
+  getAllGroups: vi.fn()
+}))
+
+vi.mock('@/api/admin', () => ({
+  adminAPI: {
+    accounts: {
+      list: listAccounts,
+      listWithEtag,
+      getBatchTodayStats,
+      delete: vi.fn(),
+      batchClearError: vi.fn(),
+      batchRefresh: vi.fn(),
+      toggleSchedulable: vi.fn()
+    },
+    proxies: {
+      getAll: getAllProxies
+    },
+    groups: {
+      getAll: getAllGroups
+    }
+  }
+}))
+
+vi.mock('@/stores/app', () => ({
+  useAppStore: () => ({
+    showError: vi.fn(),
+    showSuccess: vi.fn(),
+    showInfo: vi.fn()
+  })
+}))
+
+vi.mock('@/stores/auth', () => ({
+  useAuthStore: () => ({
+    token: 'test-token'
+  })
+}))
+
+vi.mock('vue-i18n', async () => {
+  const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
+  return {
+    ...actual,
+    useI18n: () => ({
+      t: (key: string) => key
+    })
+  }
+})
+
+const DataTableStub = {
+  props: ['columns', 'data'],
+  template: '<div data-test="data-table"></div>'
+}
+
+const AccountBulkActionsBarStub = {
+  props: ['selectedIds'],
+  emits: ['edit-filtered'],
+  template: '<button data-test="edit-filtered" @click="$emit(\'edit-filtered\')">edit filtered</button>'
+}
+
+const BulkEditAccountModalStub = {
+  props: ['show', 'target'],
+  template: '<div data-test="bulk-edit-modal" :data-show="String(show)" :data-target-mode="target?.mode ?? \'\'"></div>'
+}
+
+describe('admin AccountsView bulk edit scope', () => {
+  beforeEach(() => {
+    localStorage.clear()
+
+    listAccounts.mockReset()
+    listWithEtag.mockReset()
+    getBatchTodayStats.mockReset()
+    getAllProxies.mockReset()
+    getAllGroups.mockReset()
+
+    listAccounts.mockResolvedValue({
+      items: [],
+      total: 0,
+      page: 1,
+      page_size: 20,
+      pages: 0
+    })
+    listWithEtag.mockResolvedValue({
+      notModified: true,
+      etag: null,
+      data: null
+    })
+    getBatchTodayStats.mockResolvedValue({ stats: {} })
+    getAllProxies.mockResolvedValue([])
+    getAllGroups.mockResolvedValue([])
+  })
+
+  it('opens bulk edit in filtered-results mode from the bulk actions dropdown', async () => {
+    const wrapper = mount(AccountsView, {
+      global: {
+        stubs: {
+          AppLayout: { template: '<div><slot /></div>' },
+          TablePageLayout: {
+            template: '<div><slot name="filters" /><slot name="table" /><slot name="pagination" /></div>'
+          },
+          DataTable: DataTableStub,
+          Pagination: true,
+          ConfirmDialog: true,
+          AccountTableActions: { template: '<div><slot name="beforeCreate" /><slot name="after" /></div>' },
+          AccountTableFilters: { template: '<div></div>' },
+          AccountBulkActionsBar: AccountBulkActionsBarStub,
+          AccountActionMenu: true,
+          ImportDataModal: true,
+          ReAuthAccountModal: true,
+          AccountTestModal: true,
+          AccountStatsModal: true,
+          ScheduledTestsPanel: true,
+          SyncFromCrsModal: true,
+          TempUnschedStatusModal: true,
+          ErrorPassthroughRulesModal: true,
+          TLSFingerprintProfilesModal: true,
+          CreateAccountModal: true,
+          EditAccountModal: true,
+          BulkEditAccountModal: BulkEditAccountModalStub,
+          PlatformTypeBadge: true,
+          AccountCapacityCell: true,
+          AccountStatusIndicator: true,
+          AccountTodayStatsCell: true,
+          AccountGroupsCell: true,
+          AccountUsageCell: true,
+          Icon: true
+        }
+      }
+    })
+
+    await flushPromises()
+    await wrapper.get('[data-test="edit-filtered"]').trigger('click')
+    await flushPromises()
+
+    expect(wrapper.get('[data-test="bulk-edit-modal"]').attributes('data-show')).toBe('true')
+    expect(wrapper.get('[data-test="bulk-edit-modal"]').attributes('data-target-mode')).toBe('filtered')
+  })
+})
@@ -136,13 +136,13 @@ const validation = computed(() => {
   // Validate advanced settings
   if (advancedSettings.value) {
     const { error_log_retention_days, minute_metrics_retention_days, hourly_metrics_retention_days } = advancedSettings.value.data_retention
-    if (error_log_retention_days < 1 || error_log_retention_days > 365) {
+    if (error_log_retention_days < 0 || error_log_retention_days > 365) {
       errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
     }
-    if (minute_metrics_retention_days < 1 || minute_metrics_retention_days > 365) {
+    if (minute_metrics_retention_days < 0 || minute_metrics_retention_days > 365) {
       errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
     }
-    if (hourly_metrics_retention_days < 1 || hourly_metrics_retention_days > 365) {
+    if (hourly_metrics_retention_days < 0 || hourly_metrics_retention_days > 365) {
       errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
     }
   }
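The retention-days validation above relaxes the lower bound from 1 to 0. As a standalone sketch of the accepted range (the function name is illustrative, not from the codebase):

```typescript
// Retention-days bounds check matching the updated validation:
// values from 0 through 365 (inclusive) are accepted, so 0 is no
// longer rejected.
function isValidRetentionDays(days: number): boolean {
  return days >= 0 && days <= 365
}
```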
@@ -431,7 +431,7 @@ async function saveAllSettings() {
           <input
             v-model.number="advancedSettings.data_retention.error_log_retention_days"
             type="number"
-            min="1"
+            min="0"
             max="365"
             class="input"
           />
@@ -441,7 +441,7 @@ async function saveAllSettings() {
           <input
             v-model.number="advancedSettings.data_retention.minute_metrics_retention_days"
             type="number"
-            min="1"
+            min="0"
             max="365"
             class="input"
           />
@@ -451,7 +451,7 @@ async function saveAllSettings() {
           <input
             v-model.number="advancedSettings.data_retention.hourly_metrics_retention_days"
             type="number"
-            min="1"
+            min="0"
             max="365"
             class="input"
           />