mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df722c9a6e | ||
|
|
d9e68f2ca1 | ||
|
|
c129825f9b | ||
|
|
ff50b8b6ea | ||
|
|
4cbf518f0a | ||
|
|
dc09b367dc | ||
|
|
11fe29223d | ||
|
|
0b84d12dbb | ||
|
|
76e2503d5e | ||
|
|
3ab40269b4 | ||
|
|
650ddb2e39 | ||
|
|
0a914e034c | ||
|
|
6a41cf6a51 | ||
|
|
23555be380 | ||
|
|
47fb38bca1 | ||
|
|
72d5ee4cd1 | ||
|
|
b2bdba78dd | ||
|
|
3930bebaf9 | ||
|
|
e736de1ed9 | ||
|
|
57099a6af6 | ||
|
|
adf01ac880 | ||
|
|
4d145300c3 | ||
|
|
4e4cc80971 | ||
|
|
48912014a1 | ||
|
|
9d801595c9 | ||
|
|
9c448f89a8 | ||
|
|
73b872998e | ||
|
|
094e1171ef | ||
|
|
733627cf9d | ||
|
|
f084d30d65 | ||
|
|
3953dc9ce4 | ||
|
|
8ad099baa6 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -122,7 +122,7 @@ scripts
|
||||
.code-review-state
|
||||
#openspec/
|
||||
code-reviews/
|
||||
#AGENTS.md
|
||||
AGENTS.md
|
||||
backend/cmd/server/server
|
||||
deploy/docker-compose.override.yml
|
||||
.gocache/
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.1.119
|
||||
0.1.122
|
||||
|
||||
@@ -528,6 +528,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
// 捕获闭包内创建的账号引用,用于创建成功后触发异步探测。
|
||||
// 幂等重放时闭包不会执行 → createdAccount 为 nil → 不重复调度。
|
||||
var createdAccount *service.Account
|
||||
|
||||
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
@@ -549,6 +553,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
createdAccount = account
|
||||
// Antigravity OAuth: 新账号直接设置隐私
|
||||
h.adminService.ForceAntigravityPrivacy(ctx, account)
|
||||
// OpenAI OAuth: 新账号直接设置隐私
|
||||
@@ -577,6 +582,9 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
// OpenAI APIKey 账号创建后异步探测上游 /v1/responses 能力。
|
||||
// 探测失败不影响账号创建响应。
|
||||
h.scheduleOpenAIResponsesProbe(createdAccount)
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
|
||||
@@ -637,9 +645,39 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAI APIKey: credentials 修改后重新探测上游能力(base_url/api_key 可能变更)。
|
||||
// 异步执行,探测失败不影响账号更新响应。
|
||||
if len(req.Credentials) > 0 {
|
||||
h.scheduleOpenAIResponsesProbe(account)
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// scheduleOpenAIResponsesProbe 异步触发 OpenAI APIKey 账号的 Responses API 能力探测。
|
||||
//
|
||||
// 仅对 platform=openai && type=apikey 账号生效;其他账号无操作。
|
||||
// 探测本身在 goroutine 中执行(会发一次 HTTP 请求到上游),不会阻塞
|
||||
// 当前请求。探测错误仅记录日志,不向上下文传播:探测失败时标记保持缺失,
|
||||
// 网关会按"现状即证据"默认走 Responses。
|
||||
func (h *AccountHandler) scheduleOpenAIResponsesProbe(account *service.Account) {
|
||||
if account == nil || account.Platform != service.PlatformOpenAI || account.Type != service.AccountTypeAPIKey {
|
||||
return
|
||||
}
|
||||
if h.accountTestService == nil {
|
||||
return
|
||||
}
|
||||
accountID := account.ID
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("openai_responses_probe_panic", "account_id", accountID, "recover", r)
|
||||
}
|
||||
}()
|
||||
h.accountTestService.ProbeOpenAIAPIKeyResponsesSupport(context.Background(), accountID)
|
||||
}()
|
||||
}
|
||||
|
||||
// Delete handles deleting an account
|
||||
// DELETE /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
@@ -1231,6 +1269,8 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
openaiPrivacyAccounts = append(openaiPrivacyAccounts, account)
|
||||
}
|
||||
}
|
||||
// OpenAI APIKey 账号异步探测 /v1/responses 能力。
|
||||
h.scheduleOpenAIResponsesProbe(account)
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
|
||||
@@ -2,8 +2,11 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetUserOverview returns one user's affiliate overview.
|
||||
// GET /api/v1/admin/affiliates/users/:user_id/overview
|
||||
func (h *AffiliateHandler) GetUserOverview(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, overview)
|
||||
}
|
||||
|
||||
// ListInviteRecords returns all inviter-invitee relationships.
|
||||
// GET /api/v1/admin/affiliates/invites
|
||||
func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
// ListRebateRecords returns all order-level affiliate rebate records.
|
||||
// GET /api/v1/admin/affiliates/rebates
|
||||
func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
// ListTransferRecords returns all affiliate quota-to-balance transfer records.
|
||||
// GET /api/v1/admin/affiliates/transfers
|
||||
func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := parseAffiliateRecordFilter(c, page, pageSize)
|
||||
items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, total, filter.Page, filter.PageSize)
|
||||
}
|
||||
|
||||
func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter {
|
||||
filter := service.AffiliateRecordFilter{
|
||||
Search: c.Query("search"),
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
SortBy: c.Query("sort_by"),
|
||||
SortDesc: c.Query("sort_order") != "asc",
|
||||
}
|
||||
if filter.PageSize > 100 {
|
||||
filter.PageSize = 100
|
||||
}
|
||||
userTZ := c.Query("timezone")
|
||||
if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil {
|
||||
filter.StartAt = t
|
||||
}
|
||||
if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil {
|
||||
filter.EndAt = t
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &parsed
|
||||
}
|
||||
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
|
||||
end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond)
|
||||
return &end
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -209,6 +209,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: settings.EnableCCHSigning,
|
||||
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
||||
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
||||
@@ -441,9 +442,10 @@ type UpdateSettingsRequest struct {
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||
|
||||
// Payment visible method routing
|
||||
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
||||
@@ -1273,6 +1275,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.EnableCCHSigning
|
||||
}(),
|
||||
EnableAnthropicCacheTTL1hInjection: func() bool {
|
||||
if req.EnableAnthropicCacheTTL1hInjection != nil {
|
||||
return *req.EnableAnthropicCacheTTL1hInjection
|
||||
}
|
||||
return previousSettings.EnableAnthropicCacheTTL1hInjection
|
||||
}(),
|
||||
PaymentVisibleMethodAlipaySource: func() string {
|
||||
if req.PaymentVisibleMethodAlipaySource != nil {
|
||||
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
||||
@@ -1570,6 +1578,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: updatedSettings.EnableCCHSigning,
|
||||
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
|
||||
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
||||
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
||||
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
||||
@@ -1949,6 +1958,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EnableCCHSigning != after.EnableCCHSigning {
|
||||
changed = append(changed, "enable_cch_signing")
|
||||
}
|
||||
if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection {
|
||||
changed = append(changed, "enable_anthropic_cache_ttl_1h_injection")
|
||||
}
|
||||
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
||||
changed = append(changed, "payment_visible_method_alipay_source")
|
||||
}
|
||||
|
||||
@@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
// GetBalanceHistory handles getting user's balance/concurrency change history
|
||||
// GET /api/v1/admin/users/:id/balance-history
|
||||
// Query params:
|
||||
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
// - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
|
||||
@@ -142,9 +142,10 @@ type SystemSettings struct {
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||
|
||||
// Web Search Emulation
|
||||
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||
|
||||
@@ -262,6 +262,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// [DEBUG-STICKY] 打印会话 hash 生成结果
|
||||
reqLog.Info("sticky.session_hash_generated",
|
||||
zap.String("session_hash", sessionHash),
|
||||
zap.String("metadata_user_id_raw", parsedReq.MetadataUserID),
|
||||
)
|
||||
|
||||
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||
platform := ""
|
||||
if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
@@ -278,6 +284,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
// [DEBUG-STICKY] 打印粘性会话查询结果
|
||||
reqLog.Info("sticky.cache_lookup",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("bound_account_id", sessionBoundAccountID),
|
||||
)
|
||||
if sessionBoundAccountID > 0 {
|
||||
prefetchedGroupID := int64(0)
|
||||
if apiKey.GroupID != nil {
|
||||
@@ -286,6 +297,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
} else {
|
||||
reqLog.Info("sticky.no_session_key", zap.String("session_hash", sessionHash))
|
||||
}
|
||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
@@ -536,6 +549,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
reqLog.Info("sticky.selecting_account",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
|
||||
zap.Bool("has_bound_session", hasBoundSession),
|
||||
zap.Int("failed_account_count", len(fs.FailedAccountIDs)),
|
||||
)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
@@ -569,6 +588,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// [DEBUG-STICKY] 打印账号选择结果
|
||||
reqLog.Info("sticky.account_selected",
|
||||
zap.Int64("selected_account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.Bool("slot_acquired", selection.Acquired),
|
||||
zap.Bool("has_wait_plan", selection.WaitPlan != nil),
|
||||
zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
|
||||
zap.Bool("sticky_honored", sessionBoundAccountID > 0 && sessionBoundAccountID == account.ID),
|
||||
)
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||
@@ -635,6 +664,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
reqLog.Info("sticky.bind_after_wait",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("account_id", account.ID),
|
||||
)
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
@@ -829,6 +862,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定粘性会话(成功转发后绑定/刷新)
|
||||
// - 无现有绑定(首次请求):创建绑定
|
||||
// - 选中账号与粘性账号一致:刷新 TTL
|
||||
// - 粘性账号因负载/RPM 被跳过、选中了其他账号:不覆盖原绑定,
|
||||
// 下次请求粘性账号恢复后仍可命中
|
||||
if sessionKey != "" && (sessionBoundAccountID == 0 || sessionBoundAccountID == account.ID) {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -120,7 +121,6 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
@@ -138,32 +138,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
false,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
@@ -191,12 +167,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
@@ -276,7 +251,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UpstreamEndpoint: resolveRawCCUpstreamEndpoint(c, account),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
@@ -299,3 +274,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
|
||||
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
|
||||
// has been probed to not support the Responses API, the request is
|
||||
// forwarded directly to /v1/chat/completions — not through the default
|
||||
// CC→Responses conversion path.
|
||||
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {
|
||||
if account != nil && account.Type == service.AccountTypeAPIKey &&
|
||||
!openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return "/v1/chat/completions"
|
||||
}
|
||||
return GetUpstreamEndpoint(c, account.Platform)
|
||||
}
|
||||
|
||||
@@ -37,16 +37,6 @@ type OpenAIGatewayHandler struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string {
|
||||
if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" {
|
||||
return fallbackModel
|
||||
}
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
|
||||
}
|
||||
|
||||
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
@@ -1233,6 +1223,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -352,30 +353,6 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
|
||||
t.Run("prefers_explicit_fallback_model", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
|
||||
}
|
||||
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
|
||||
})
|
||||
|
||||
t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
|
||||
}
|
||||
require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, ""))
|
||||
})
|
||||
|
||||
t.Run("returns_empty_without_group_default", func(t *testing.T) {
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, ""))
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, ""))
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{
|
||||
Group: &service.Group{},
|
||||
}, ""))
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) {
|
||||
t.Run("exact_claude_model_override_wins", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
@@ -651,6 +628,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 test"),
|
||||
})
|
||||
|
||||
require.NotNil(t, got.log.UserAgent)
|
||||
require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent)
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "high", *got.log.ReasoningEffort)
|
||||
require.True(t, got.log.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"),
|
||||
channelMapping: map[string]string{
|
||||
"gpt-5.4-xhigh": "gpt-5.4",
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(),
|
||||
"上游首帧应使用渠道映射后的模型")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "xhigh", *got.log.ReasoningEffort,
|
||||
"usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
|
||||
userAgent: testStringPtr(""),
|
||||
})
|
||||
|
||||
require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "medium", *got.log.ReasoningEffort)
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -796,3 +813,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
return httptest.NewServer(router)
|
||||
}
|
||||
|
||||
// openAIResponsesWSUsageLogCase describes one passthrough-WS usage-log
// scenario: the first client frame, the inbound User-Agent (nil = header
// absent, empty string = header present but blank), and an optional
// channel model mapping.
type openAIResponsesWSUsageLogCase struct {
	firstPayload   string
	userAgent      *string
	channelMapping map[string]string
}
|
||||
|
||||
type openAIResponsesWSUsageLogResult struct {
|
||||
log *service.UsageLog
|
||||
upstreamFirstPayload []byte
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
if s.account.Platform != platform {
|
||||
return nil, nil
|
||||
}
|
||||
return []service.Account{s.account}, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID != id {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.account
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||
service.UsageLogRepository
|
||||
created chan *service.UsageLog
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if s.created != nil {
|
||||
s.created <- log
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerChannelRepoStub struct {
|
||||
service.ChannelRepository
|
||||
channels []service.Channel
|
||||
groupPlatforms map[int64]string
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
return s.channels, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
out := make(map[int64]string, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" {
|
||||
out[groupID] = platform
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstreamPayloadCh := make(chan []byte, 1)
|
||||
upstreamErrCh := make(chan error, 1)
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
upstreamErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr != nil {
|
||||
upstreamErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
upstreamErrCh <- errors.New("unexpected upstream websocket message type")
|
||||
return
|
||||
}
|
||||
upstreamPayloadCh <- payload
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||
))
|
||||
cancelWrite()
|
||||
if writeErr != nil {
|
||||
upstreamErrCh <- writeErr
|
||||
return
|
||||
}
|
||||
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||
upstreamErrCh <- nil
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
groupID := int64(4201)
|
||||
account := service.Account{
|
||||
ID: 9901,
|
||||
Name: "openai-ws-passthrough-usage-e2e",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": upstreamServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.RunMode = config.RunModeSimple
|
||||
cfg.Default.RateMultiplier = 1
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
|
||||
usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
|
||||
|
||||
var channelSvc *service.ChannelService
|
||||
if len(tc.channelMapping) > 0 {
|
||||
channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
|
||||
channels: []service.Channel{{
|
||||
ID: 7701,
|
||||
Name: "openai-ws-e2e-channel",
|
||||
Status: service.StatusActive,
|
||||
GroupIDs: []int64{groupID},
|
||||
ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
|
||||
}},
|
||||
groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
usageRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
service.NewBillingService(cfg, nil),
|
||||
nil,
|
||||
billingCacheSvc,
|
||||
nil,
|
||||
&service.DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
channelSvc,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: gatewaySvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1801,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1701, Status: service.StatusActive},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
handlerServer := httptest.NewServer(router)
|
||||
defer handlerServer.Close()
|
||||
|
||||
headers := http.Header{}
|
||||
if tc.userAgent != nil {
|
||||
headers.Set("User-Agent", *tc.userAgent)
|
||||
}
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(
|
||||
dialCtx,
|
||||
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||
&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
|
||||
)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, event, err := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
var usageLog *service.UsageLog
|
||||
select {
|
||||
case usageLog = <-usageRepo.created:
|
||||
require.NotNil(t, usageLog)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待 WebSocket usage log 写入超时")
|
||||
}
|
||||
|
||||
var upstreamFirstPayload []byte
|
||||
select {
|
||||
case upstreamFirstPayload = <-upstreamPayloadCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 首帧超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case upstreamErr := <-upstreamErrCh:
|
||||
require.NoError(t, upstreamErr)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 结束超时")
|
||||
}
|
||||
|
||||
return openAIResponsesWSUsageLogResult{
|
||||
log: usageLog,
|
||||
upstreamFirstPayload: upstreamFirstPayload,
|
||||
}
|
||||
}
|
||||
|
||||
// testStringPtr returns a pointer to a copy of v. Test fixtures use it to
// populate optional *string fields inline.
func testStringPtr(v string) *string {
	s := v
	return &s
}
|
||||
|
||||
@@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) {
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
// TestResponsesEventToAnthropicEvents_ResponseDone verifies that a
// "response.done" terminal event with status "completed" converts to the
// Anthropic pair message_delta (stop_reason=end_turn, usage carried over)
// followed by message_stop, and that the finalizer then has nothing left
// to emit.
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
	state := NewResponsesEventToAnthropicState()
	state.Model = "gpt-4o"

	events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage:  &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
		},
	}, state)
	require.Len(t, events, 2)
	assert.Equal(t, "message_delta", events[0].Type)
	assert.Equal(t, "end_turn", events[0].Delta.StopReason)
	// Usage from the Responses payload must be forwarded unchanged.
	assert.Equal(t, 12, events[0].Usage.InputTokens)
	assert.Equal(t, 4, events[0].Usage.OutputTokens)
	assert.Equal(t, "message_stop", events[1].Type)
	// After a terminal event has been emitted, finalize must be a no-op.
	assert.Nil(t, FinalizeResponsesAnthropicStream(state))
}
|
||||
|
||||
// TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete verifies that a
// "response.done" event with status "incomplete" (reason max_output_tokens)
// maps to the Anthropic stop_reason "max_tokens" before message_stop.
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
	state := NewResponsesEventToAnthropicState()
	state.Model = "gpt-4o"

	events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status:            "incomplete",
			IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
			Usage:             &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
		},
	}, state)
	require.Len(t, events, 2)
	assert.Equal(t, "message_delta", events[0].Type)
	// "max_output_tokens" incompleteness maps to Anthropic's "max_tokens".
	assert.Equal(t, "max_tokens", events[0].Delta.StopReason)
	assert.Equal(t, "message_stop", events[1].Type)
	// After a terminal event has been emitted, finalize must be a no-op.
	assert.Nil(t, FinalizeResponsesAnthropicStream(state))
}
|
||||
|
||||
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
|
||||
@@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
// TestResponsesEventToChatChunks_ResponseDone verifies that a "response.done"
// event with status "completed" yields a finish chunk (finish_reason=stop)
// plus a trailing usage chunk when IncludeUsage is set, and that the
// finalizer then has nothing left to emit.
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage:  &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
		},
	}, state)
	require.Len(t, chunks, 2)
	require.NotNil(t, chunks[0].Choices[0].FinishReason)
	assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
	// Second chunk is the usage-only chunk produced because IncludeUsage=true.
	require.NotNil(t, chunks[1].Usage)
	assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
	assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
	// After a terminal event has been emitted, finalize must be a no-op.
	assert.Nil(t, FinalizeResponsesChatStream(state))
}
|
||||
|
||||
// TestResponsesEventToChatChunks_ResponseDoneIncomplete verifies that a
// "response.done" event with status "incomplete" (reason max_output_tokens)
// maps to the chat-completions finish_reason "length", while usage is still
// reported in the trailing chunk.
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
	state := NewResponsesEventToChatState()
	state.Model = "gpt-4o"
	state.IncludeUsage = true

	chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status:            "incomplete",
			IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
			Usage:             &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
		},
	}, state)
	require.Len(t, chunks, 2)
	require.NotNil(t, chunks[0].Choices[0].FinishReason)
	// "max_output_tokens" incompleteness maps to OpenAI chat's "length".
	assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason)
	require.NotNil(t, chunks[1].Usage)
	assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
	assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
	// After a terminal event has been emitted, finalize must be a no-op.
	assert.Nil(t, FinalizeResponsesChatStream(state))
}
|
||||
|
||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
@@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents(
|
||||
return resToAnthHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||
return resToAnthHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
|
||||
@@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
|
||||
@@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct {
|
||||
type ResponsesStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// response.created / response.completed / response.failed / response.incomplete
|
||||
// response.created / response.completed / response.done / response.failed / response.incomplete
|
||||
Response *ResponsesResponse `json:"response,omitempty"`
|
||||
|
||||
// response.output_item.added / response.output_item.done
|
||||
|
||||
75
backend/internal/pkg/openai_compat/upstream_capability.go
Normal file
75
backend/internal/pkg/openai_compat/upstream_capability.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Package openai_compat 提供 OpenAI 协议族在不同上游间的能力差异判定工具。
|
||||
//
|
||||
// 背景:sub2api 的 OpenAI APIKey 账号通过 base_url 接入多种第三方 OpenAI 兼容上游
|
||||
// (DeepSeek、Kimi、GLM、Qwen 等)。这些上游普遍只支持 /v1/chat/completions,
|
||||
// 不存在 /v1/responses 端点。但网关历史代码无差别走 CC→Responses 转换并打到
|
||||
// /v1/responses,导致兼容上游 404。
|
||||
//
|
||||
// 本包提供基于"账号探测标记"的能力判定,配合
|
||||
// internal/service/openai_apikey_responses_probe.go 在创建/修改账号时一次性
|
||||
// 探测并落标。
|
||||
//
|
||||
// 设计取舍:
|
||||
// - 不维护静态 host 白名单——避免新增厂商时必须改代码(讨论沉淀于
|
||||
// pensieve/short-term/knowledge/upstream-capability-detection-design-tradeoffs)
|
||||
// - 标记缺失时默认 true(即"走 Responses"),保持与重构前老代码完全一致的存量
|
||||
// 账号行为("现状即证据"原则;详见
|
||||
// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems)
|
||||
package openai_compat
|
||||
|
||||
// AccountResponsesSupport describes whether an account's upstream supports
// the OpenAI Responses API.
//
// Only meaningful for platform=openai + type=apikey accounts; other account
// types must not be evaluated through this package.
type AccountResponsesSupport int

const (
	// ResponsesSupportUnknown means the account has not been probed yet
	// (the extra field is absent). Routing should default to the Responses
	// path, preserving pre-refactor behavior for existing accounts.
	ResponsesSupportUnknown AccountResponsesSupport = iota

	// ResponsesSupportYes means probing confirmed the upstream supports
	// /v1/responses.
	ResponsesSupportYes

	// ResponsesSupportNo means probing confirmed the upstream lacks
	// /v1/responses; requests should take the /v1/chat/completions
	// direct-conversion path instead.
	ResponsesSupportNo
)

// ExtraKeyResponsesSupported is the accounts.extra JSON key that stores the
// probe result. Value type is bool: true=supported, false=unsupported,
// key absent=not probed.
const ExtraKeyResponsesSupported = "openai_responses_supported"

// ResolveResponsesSupport reads the probe marker from an account's extra map.
//
// A missing key or a value of the wrong type yields ResponsesSupportUnknown;
// callers should treat unknown as "not probed = keep legacy behavior = use
// Responses" (see ShouldUseResponsesAPI).
func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
	// Indexing a nil map is legal in Go and yields the zero value, so one
	// combined lookup-and-type-assert covers nil map, missing key, and a
	// non-bool value in a single expression.
	switch supported, ok := extra[ExtraKeyResponsesSupported].(bool); {
	case !ok:
		return ResponsesSupportUnknown
	case supported:
		return ResponsesSupportYes
	default:
		return ResponsesSupportNo
	}
}

// ShouldUseResponsesAPI reports whether an inbound /v1/chat/completions
// request on an OpenAI APIKey account should take the
// "CC→Responses conversion + upstream /v1/responses" path.
//
// It returns true in two cases:
//  1. probing confirmed the upstream supports Responses, or
//  2. the account was never probed (marker absent) — legacy behavior is
//     preserved on purpose.
//
// Only a probed, confirmed-unsupported account returns false; callers then
// use the direct chat-completions path.
func ShouldUseResponsesAPI(extra map[string]any) bool {
	return ResolveResponsesSupport(extra) != ResponsesSupportNo
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package openai_compat
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveResponsesSupport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
want AccountResponsesSupport
|
||||
}{
|
||||
{"nil extra", nil, ResponsesSupportUnknown},
|
||||
{"empty extra", map[string]any{}, ResponsesSupportUnknown},
|
||||
{"key missing", map[string]any{"other": "value"}, ResponsesSupportUnknown},
|
||||
{"value true", map[string]any{ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
|
||||
{"value false", map[string]any{ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
|
||||
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
|
||||
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
|
||||
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ResolveResponsesSupport(tc.extra)
|
||||
if got != tc.want {
|
||||
t.Errorf("ResolveResponsesSupport(%v) = %v, want %v", tc.extra, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldUseResponsesAPI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
want bool
|
||||
}{
|
||||
// 关键不变量:未探测必须返回 true(保留旧行为)
|
||||
{"unknown defaults to true (preserve old behavior)", nil, true},
|
||||
{"unknown empty defaults to true", map[string]any{}, true},
|
||||
{"unknown wrong type defaults to true", map[string]any{ExtraKeyResponsesSupported: "yes"}, true},
|
||||
|
||||
// 已探测:标记决定
|
||||
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
|
||||
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ShouldUseResponsesAPI(tc.extra)
|
||||
if got != tc.want {
|
||||
t.Errorf("ShouldUseResponsesAPI(%v) = %v, want %v", tc.extra, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,34 @@ const (
|
||||
|
||||
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
|
||||
|
||||
const affiliateUserOverviewSQL = `
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.aff_code,
|
||||
COALESCE(ua.aff_rebate_rate_percent, 0)::double precision,
|
||||
(ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate,
|
||||
ua.aff_count,
|
||||
COALESCE(rebated.rebated_invitee_count, 0),
|
||||
(ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision,
|
||||
ua.aff_history_quota::double precision
|
||||
FROM user_affiliates ua
|
||||
JOIN users u ON u.id = ua.user_id
|
||||
LEFT JOIN (
|
||||
SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count
|
||||
FROM user_affiliate_ledger
|
||||
WHERE action = 'accrue' AND source_user_id IS NOT NULL
|
||||
GROUP BY user_id
|
||||
) rebated ON rebated.user_id = ua.user_id
|
||||
LEFT JOIN (
|
||||
SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota
|
||||
FROM user_affiliate_ledger
|
||||
WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW()
|
||||
GROUP BY user_id
|
||||
) matured ON matured.user_id = ua.user_id
|
||||
WHERE ua.user_id = $1
|
||||
LIMIT 1`
|
||||
|
||||
type affiliateQueryExecer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
@@ -86,7 +114,7 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
|
||||
return bound, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) {
|
||||
if amount <= 0 {
|
||||
return false, nil
|
||||
}
|
||||
@@ -112,15 +140,15 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
|
||||
|
||||
if freezeHours > 0 {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
|
||||
inviterID, amount, inviteeUserID, freezeHours); err != nil {
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`,
|
||||
inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
} else {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -275,9 +303,32 @@ FROM cleared`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
|
||||
INSERT INTO user_affiliate_ledger (
|
||||
user_id,
|
||||
action,
|
||||
amount,
|
||||
source_user_id,
|
||||
balance_after,
|
||||
aff_quota_after,
|
||||
aff_frozen_quota_after,
|
||||
aff_history_quota_after,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`,
|
||||
userID,
|
||||
transferred,
|
||||
snapshot.BalanceAfter,
|
||||
snapshot.AvailableQuotaAfter,
|
||||
snapshot.FrozenQuotaAfter,
|
||||
snapshot.HistoryQuotaAfter,
|
||||
); err != nil {
|
||||
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
|
||||
}
|
||||
|
||||
@@ -332,6 +383,349 @@ LIMIT $2`, inviterID, limit)
|
||||
return invitees, nil
|
||||
}
|
||||
|
||||
// ListAffiliateInviteRecords returns one row per inviter→invitee binding,
// joined with both users' identities, the inviter's aff code, and the total
// 'accrue' rebate amount attributed to that invitee, paginated per filter.
//
// Placeholder numbering is positional: buildAffiliateRecordWhere assigns
// $1..$k, then PageSize and the computed offset are appended, so LIMIT and
// OFFSET always use the last two indices of args. The count query must run
// with the pre-append args.
func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
	client := clientFromContext(ctx, r.client)
	where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{
		"inviter.email", "inviter.username", "invitee.email", "invitee.username",
		"ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code",
	})

	// Count with the same joins and WHERE as the page query (before the
	// LIMIT/OFFSET args are appended).
	total, err := queryAffiliateRecordCount(ctx, client, `
SELECT COUNT(*)
FROM user_affiliates ua
JOIN users invitee ON invitee.id = ua.user_id
JOIN users inviter ON inviter.id = ua.inviter_id
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
`+where, args...)
	if err != nil {
		return nil, 0, err
	}

	// Map of API sort keys to SQL columns; unknown keys fall back to
	// ua.created_at inside buildAffiliateRecordOrderBy.
	orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
		"inviter":      "inviter.email",
		"invitee":      "invitee.email",
		"aff_code":     "inviter_aff.aff_code",
		"total_rebate": "total_rebate",
		"created_at":   "ua.created_at",
	}, "ua.created_at")
	args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
	rows, err := client.QueryContext(ctx, `
SELECT ua.inviter_id,
COALESCE(inviter.email, ''),
COALESCE(inviter.username, ''),
ua.user_id,
COALESCE(invitee.email, ''),
COALESCE(invitee.username, ''),
COALESCE(inviter_aff.aff_code, ''),
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate,
ua.created_at
FROM user_affiliates ua
JOIN users invitee ON invitee.id = ua.user_id
JOIN users inviter ON inviter.id = ua.inviter_id
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
LEFT JOIN user_affiliate_ledger ual
ON ual.user_id = ua.inviter_id
AND ual.source_user_id = ua.user_id
AND ual.action = 'accrue'
`+where+`
GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]service.AffiliateInviteRecord, 0)
	for rows.Next() {
		var item service.AffiliateInviteRecord
		// Scan order must match the SELECT column list above exactly.
		if err := rows.Scan(
			&item.InviterID,
			&item.InviterEmail,
			&item.InviterUsername,
			&item.InviteeID,
			&item.InviteeEmail,
			&item.InviteeUsername,
			&item.AffCode,
			&item.TotalRebate,
			&item.CreatedAt,
		); err != nil {
			return nil, 0, err
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return items, total, nil
}
|
||||
|
||||
// ListAffiliateRebateRecords lists 'accrue' ledger entries that carry a
// source_order_id, joined with the payment order and both users, paginated
// per filter.
//
// baseJoin already contains a fixed WHERE clause, so the dynamic filter
// produced by buildAffiliateRecordWhere has its leading "WHERE " rewritten
// to " AND " before concatenation.
func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
	client := clientFromContext(ctx, r.client)
	where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
		"inviter.email", "inviter.username", "invitee.email", "invitee.username",
		"po.id::text", "po.out_trade_no", "po.payment_type", "po.status",
	})
	baseJoin := `
FROM user_affiliate_ledger ual
JOIN payment_orders po ON po.id = ual.source_order_id
JOIN users invitee ON invitee.id = ual.source_user_id
JOIN users inviter ON inviter.id = ual.user_id
WHERE ual.action = 'accrue'
AND ual.source_order_id IS NOT NULL`
	if where != "" {
		// baseJoin already has a WHERE; chain the dynamic filter with AND.
		where = strings.Replace(where, "WHERE ", " AND ", 1)
	}

	// Count with the pre-append args; PageSize/offset are appended below.
	total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
	if err != nil {
		return nil, 0, err
	}

	orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
		"order":         "po.id",
		"inviter":       "inviter.email",
		"invitee":       "invitee.email",
		"order_amount":  "po.amount",
		"pay_amount":    "po.pay_amount",
		"rebate_amount": "ual.amount",
		"payment_type":  "po.payment_type",
		"order_status":  "po.status",
		"created_at":    "ual.created_at",
	}, "ual.created_at")
	args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
	rows, err := client.QueryContext(ctx, `
SELECT po.id,
po.out_trade_no,
ual.user_id,
COALESCE(inviter.email, ''),
COALESCE(inviter.username, ''),
ual.source_user_id,
COALESCE(invitee.email, ''),
COALESCE(invitee.username, ''),
po.amount::double precision,
po.pay_amount::double precision,
ual.amount::double precision,
po.payment_type,
po.status,
ual.created_at
`+baseJoin+where+`
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]service.AffiliateRebateRecord, 0)
	for rows.Next() {
		var item service.AffiliateRebateRecord
		// Scan order must match the SELECT column list above exactly.
		if err := rows.Scan(
			&item.OrderID,
			&item.OutTradeNo,
			&item.InviterID,
			&item.InviterEmail,
			&item.InviterUsername,
			&item.InviteeID,
			&item.InviteeEmail,
			&item.InviteeUsername,
			&item.OrderAmount,
			&item.PayAmount,
			&item.RebateAmount,
			&item.PaymentType,
			&item.OrderStatus,
			&item.CreatedAt,
		); err != nil {
			return nil, 0, err
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return items, total, nil
}
|
||||
|
||||
// ListAffiliateTransferRecords lists 'transfer' ledger entries with the
// transferring user's identity and the post-transfer balance/quota snapshot,
// paginated per filter.
//
// The four *_after snapshot columns are nullable (older rows predate the
// snapshot feature), so they are scanned via sql.NullFloat64 and
// SnapshotAvailable is only set when all four are present.
func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
	client := clientFromContext(ctx, r.client)
	where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
		"u.email", "u.username", "u.id::text",
	})
	baseJoin := `
FROM user_affiliate_ledger ual
JOIN users u ON u.id = ual.user_id
WHERE ual.action = 'transfer'`
	if where != "" {
		// baseJoin already has a WHERE; chain the dynamic filter with AND.
		where = strings.Replace(where, "WHERE ", " AND ", 1)
	}

	// Count with the pre-append args; PageSize/offset are appended below.
	total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
	if err != nil {
		return nil, 0, err
	}

	orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
		"user":                  "u.email",
		"amount":                "ual.amount",
		"balance_after":         "ual.balance_after",
		"available_quota_after": "ual.aff_quota_after",
		"frozen_quota_after":    "ual.aff_frozen_quota_after",
		"history_quota_after":   "ual.aff_history_quota_after",
		"created_at":            "ual.created_at",
	}, "ual.created_at")
	args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
	rows, err := client.QueryContext(ctx, `
SELECT ual.id,
ual.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ual.amount::double precision,
ual.balance_after::double precision,
ual.aff_quota_after::double precision,
ual.aff_frozen_quota_after::double precision,
ual.aff_history_quota_after::double precision,
ual.created_at
`+baseJoin+where+`
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]service.AffiliateTransferRecord, 0)
	for rows.Next() {
		var item service.AffiliateTransferRecord
		// Snapshot columns may be NULL on legacy rows.
		var balanceAfter sql.NullFloat64
		var availableQuotaAfter sql.NullFloat64
		var frozenQuotaAfter sql.NullFloat64
		var historyQuotaAfter sql.NullFloat64
		// Scan order must match the SELECT column list above exactly.
		if err := rows.Scan(
			&item.LedgerID,
			&item.UserID,
			&item.UserEmail,
			&item.Username,
			&item.Amount,
			&balanceAfter,
			&availableQuotaAfter,
			&frozenQuotaAfter,
			&historyQuotaAfter,
			&item.CreatedAt,
		); err != nil {
			return nil, 0, err
		}
		item.BalanceAfter = nullableFloat64Ptr(balanceAfter)
		item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter)
		item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter)
		item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter)
		// The snapshot is only usable if every component was recorded.
		item.SnapshotAvailable = balanceAfter.Valid &&
			availableQuotaAfter.Valid &&
			frozenQuotaAfter.Valid &&
			historyQuotaAfter.Valid
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return items, total, nil
}
|
||||
|
||||
// GetAffiliateUserOverview loads a single user's affiliate summary via
// affiliateUserOverviewSQL.
//
// Returns service.ErrUserNotFound both for a non-positive userID and when no
// user_affiliates row exists. The custom rebate rate is only copied onto the
// overview when the column was non-NULL (hasCustomRate); otherwise
// RebateRatePercent stays at its zero value and RebateRateCustom is false.
func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) {
	if userID <= 0 {
		return nil, service.ErrUserNotFound
	}
	client := clientFromContext(ctx, r.client)
	rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	if !rows.Next() {
		// Distinguish an iteration error from a genuinely absent row.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, service.ErrUserNotFound
	}

	var overview service.AffiliateUserOverview
	var customRate float64
	var hasCustomRate bool
	// Scan order must match the column list in affiliateUserOverviewSQL.
	if err := rows.Scan(
		&overview.UserID,
		&overview.Email,
		&overview.Username,
		&overview.AffCode,
		&customRate,
		&hasCustomRate,
		&overview.InvitedCount,
		&overview.RebatedInviteeCount,
		&overview.AvailableQuota,
		&overview.HistoryQuota,
	); err != nil {
		return nil, err
	}
	if hasCustomRate {
		overview.RebateRatePercent = customRate
		overview.RebateRateCustom = true
	}
	return &overview, rows.Err()
}
|
||||
|
||||
// buildAffiliateRecordWhere assembles the dynamic WHERE clause shared by the
// affiliate record list queries.
//
// It supports an optional [StartAt, EndAt] window on timeColumn and a
// case-insensitive substring search OR'ed across searchColumns. Positional
// placeholders ($1, $2, ...) are numbered by append order into args, so the
// caller must pass args through unchanged and append any further parameters
// (e.g. LIMIT/OFFSET) after this call. Returns an empty clause (with empty
// args) when no filter applies.
func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) {
	clauses := make([]string, 0, 3)
	args := make([]any, 0, 3)
	if filter.StartAt != nil {
		args = append(args, *filter.StartAt)
		// Placeholder index equals the arg's 1-based position in args.
		clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args)))
	}
	if filter.EndAt != nil {
		args = append(args, *filter.EndAt)
		clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args)))
	}
	search := strings.TrimSpace(filter.Search)
	if search != "" && len(searchColumns) > 0 {
		// One shared pattern argument; every column comparison reuses the
		// same placeholder index.
		args = append(args, "%"+strings.ToLower(search)+"%")
		parts := make([]string, 0, len(searchColumns))
		for _, col := range searchColumns {
			parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args)))
		}
		clauses = append(clauses, "("+strings.Join(parts, " OR ")+")")
	}
	if len(clauses) == 0 {
		return "", args
	}
	return "WHERE " + strings.Join(clauses, " AND "), args
}
|
||||
|
||||
func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string {
|
||||
column := sortColumns[filter.SortBy]
|
||||
if column == "" {
|
||||
column = fallbackColumn
|
||||
}
|
||||
direction := "DESC"
|
||||
if !filter.SortDesc {
|
||||
direction = "ASC"
|
||||
}
|
||||
return "ORDER BY " + column + " " + direction + " NULLS LAST"
|
||||
}
|
||||
|
||||
func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
return 0, rows.Err()
|
||||
}
|
||||
var total int64
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total, rows.Err()
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return fn(ctx, tx.Client())
|
||||
@@ -516,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
type affiliateTransferSnapshot struct {
|
||||
BalanceAfter float64
|
||||
AvailableQuotaAfter float64
|
||||
FrozenQuotaAfter float64
|
||||
HistoryQuotaAfter float64
|
||||
}
|
||||
|
||||
func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT u.balance::double precision,
|
||||
ua.aff_quota::double precision,
|
||||
ua.aff_frozen_quota::double precision,
|
||||
ua.aff_history_quota::double precision
|
||||
FROM users u
|
||||
JOIN user_affiliates ua ON ua.user_id = u.id
|
||||
WHERE u.id = $1
|
||||
LIMIT 1`, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
|
||||
var snapshot affiliateTransferSnapshot
|
||||
if err := rows.Scan(
|
||||
&snapshot.BalanceAfter,
|
||||
&snapshot.AvailableQuotaAfter,
|
||||
&snapshot.FrozenQuotaAfter,
|
||||
&snapshot.HistoryQuotaAfter,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &snapshot, rows.Err()
|
||||
}
|
||||
|
||||
func nullableFloat64Ptr(v sql.NullFloat64) *float64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
return &v.Float64
|
||||
}
|
||||
|
||||
func generateAffiliateCode() (string, error) {
|
||||
buf := make([]byte, affiliateCodeLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
@@ -674,6 +1116,13 @@ func nullableArg(v *float64) any {
|
||||
return *v
|
||||
}
|
||||
|
||||
func nullableInt64Arg(v *int64) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
|
||||
//
|
||||
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索":
|
||||
|
||||
@@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
|
||||
ledgerCount := querySingleInt(t, txCtx, client,
|
||||
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
|
||||
require.Equal(t, 1, ledgerCount)
|
||||
|
||||
rows, err := client.QueryContext(txCtx, `
|
||||
SELECT amount::double precision,
|
||||
balance_after::double precision,
|
||||
aff_quota_after::double precision,
|
||||
aff_frozen_quota_after::double precision,
|
||||
aff_history_quota_after::double precision
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1 AND action = 'transfer'
|
||||
LIMIT 1`, u.ID)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = rows.Close() }()
|
||||
require.True(t, rows.Next(), "expected transfer ledger")
|
||||
var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64
|
||||
require.NoError(t, rows.Scan(&amount, &balanceAfter, "aAfter, &frozenAfter, &historyAfter))
|
||||
require.InDelta(t, 12.34, amount, 1e-9)
|
||||
require.InDelta(t, 17.84, balanceAfter, 1e-9)
|
||||
require.InDelta(t, 0.0, quotaAfter, 1e-9)
|
||||
require.InDelta(t, 0.0, frozenAfter, 1e-9)
|
||||
require.InDelta(t, 12.34, historyAfter, 1e-9)
|
||||
}
|
||||
|
||||
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
|
||||
@@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, bound, "invitee must bind to inviter")
|
||||
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, applied, "AccrueQuota must report applied=true")
|
||||
|
||||
|
||||
28
backend/internal/repository/affiliate_repo_test.go
Normal file
28
backend/internal/repository/affiliate_repo_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) {
|
||||
query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ")
|
||||
|
||||
require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)")
|
||||
require.Contains(t, query, "frozen_until <= NOW()")
|
||||
}
|
||||
|
||||
func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) {
|
||||
source, err := os.ReadFile("affiliate_repo.go")
|
||||
require.NoError(t, err)
|
||||
content := string(source)
|
||||
|
||||
require.Contains(t, content, "JOIN payment_orders po ON po.id = ual.source_order_id")
|
||||
require.Contains(t, content, "ual.amount::double precision")
|
||||
require.Contains(t, content, "ual.balance_after::double precision")
|
||||
require.NotContains(t, content, "parseAffiliateRebateAmount")
|
||||
require.NotContains(t, content, `"current_balance": "u.balance"`)
|
||||
}
|
||||
@@ -449,11 +449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
|
||||
SessionWindowStart: account.SessionWindowStart,
|
||||
SessionWindowEnd: account.SessionWindowEnd,
|
||||
SessionWindowStatus: account.SessionWindowStatus,
|
||||
AccountGroups: filterSchedulerAccountGroups(account.AccountGroups),
|
||||
GroupIDs: filterSchedulerGroupIDs(account.GroupIDs, account.AccountGroups),
|
||||
Credentials: filterSchedulerCredentials(account.Credentials),
|
||||
Extra: filterSchedulerExtra(account.Extra),
|
||||
}
|
||||
}
|
||||
|
||||
func filterSchedulerAccountGroups(accountGroups []service.AccountGroup) []service.AccountGroup {
|
||||
if len(accountGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
filtered := make([]service.AccountGroup, 0, len(accountGroups))
|
||||
for _, ag := range accountGroups {
|
||||
if ag.GroupID <= 0 {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, service.AccountGroup{
|
||||
AccountID: ag.AccountID,
|
||||
GroupID: ag.GroupID,
|
||||
Priority: ag.Priority,
|
||||
CreatedAt: ag.CreatedAt,
|
||||
})
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterSchedulerGroupIDs(groupIDs []int64, accountGroups []service.AccountGroup) []int64 {
|
||||
if len(groupIDs) == 0 && len(accountGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[int64]struct{}, len(groupIDs)+len(accountGroups))
|
||||
filtered := make([]int64, 0, len(groupIDs)+len(accountGroups))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
filtered = append(filtered, id)
|
||||
}
|
||||
for _, ag := range accountGroups {
|
||||
if ag.GroupID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[ag.GroupID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[ag.GroupID] = struct{}{}
|
||||
filtered = append(filtered, ag.GroupID)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
|
||||
if len(credentials) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
|
||||
SessionWindowStart: &now,
|
||||
SessionWindowEnd: &windowEnd,
|
||||
SessionWindowStatus: "active",
|
||||
GroupIDs: []int64{bucket.GroupID},
|
||||
AccountGroups: []service.AccountGroup{
|
||||
{
|
||||
AccountID: 101,
|
||||
GroupID: bucket.GroupID,
|
||||
Priority: 5,
|
||||
Group: &service.Group{ID: bucket.GroupID, Name: "gemini-group"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
|
||||
@@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
|
||||
require.Equal(t, 4, got.GetMaxSessions())
|
||||
require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
|
||||
require.Nil(t, got.Extra["unused_large_field"])
|
||||
require.Equal(t, []int64{bucket.GroupID}, got.GroupIDs)
|
||||
require.Len(t, got.AccountGroups, 1)
|
||||
require.Equal(t, account.ID, got.AccountGroups[0].AccountID)
|
||||
require.Equal(t, bucket.GroupID, got.AccountGroups[0].GroupID)
|
||||
require.Nil(t, got.AccountGroups[0].Group)
|
||||
|
||||
full, err := cache.GetAccount(ctx, account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, full)
|
||||
require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
|
||||
require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
|
||||
require.Len(t, full.AccountGroups, 1)
|
||||
require.NotNil(t, full.AccountGroups[0].Group)
|
||||
}
|
||||
|
||||
@@ -31,3 +31,43 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
|
||||
require.Equal(t, true, got.Extra["mixed_scheduling"])
|
||||
require.Nil(t, got.Extra["unused_large_field"])
|
||||
}
|
||||
|
||||
func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) {
|
||||
account := service.Account{
|
||||
ID: 42,
|
||||
Platform: service.PlatformAnthropic,
|
||||
GroupIDs: []int64{7, 9, 7, 0},
|
||||
AccountGroups: []service.AccountGroup{
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 7,
|
||||
Priority: 2,
|
||||
Account: &service.Account{ID: 42, Name: "drop-from-metadata"},
|
||||
Group: &service.Group{ID: 7, Name: "drop-from-metadata"},
|
||||
},
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 11,
|
||||
Priority: 3,
|
||||
Group: &service.Group{ID: 11, Name: "drop-from-metadata"},
|
||||
},
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 0,
|
||||
Priority: 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := buildSchedulerMetadataAccount(account)
|
||||
|
||||
require.Equal(t, []int64{7, 9, 11}, got.GroupIDs)
|
||||
require.Len(t, got.AccountGroups, 2)
|
||||
require.Equal(t, int64(42), got.AccountGroups[0].AccountID)
|
||||
require.Equal(t, int64(7), got.AccountGroups[0].GroupID)
|
||||
require.Equal(t, 2, got.AccountGroups[0].Priority)
|
||||
require.Nil(t, got.AccountGroups[0].Account)
|
||||
require.Nil(t, got.AccountGroups[0].Group)
|
||||
require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
|
||||
require.Nil(t, got.Groups)
|
||||
}
|
||||
|
||||
@@ -740,6 +740,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"enable_cch_signing": false,
|
||||
"enable_anthropic_cache_ttl_1h_injection": false,
|
||||
"enable_fingerprint_unification": true,
|
||||
"enable_metadata_passthrough": false,
|
||||
"web_search_emulation_enabled": false,
|
||||
@@ -934,6 +935,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"enable_fingerprint_unification": true,
|
||||
"enable_metadata_passthrough": false,
|
||||
"enable_cch_signing": false,
|
||||
"enable_anthropic_cache_ttl_1h_injection": false,
|
||||
"web_search_emulation_enabled": false,
|
||||
"payment_visible_method_alipay_source": "",
|
||||
"payment_visible_method_wxpay_source": "",
|
||||
|
||||
@@ -602,11 +602,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
affiliates := admin.Group("/affiliates")
|
||||
{
|
||||
affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords)
|
||||
affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords)
|
||||
affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords)
|
||||
|
||||
users := affiliates.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.Affiliate.ListUsers)
|
||||
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
|
||||
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
|
||||
users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview)
|
||||
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
|
||||
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -554,7 +555,16 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
|
||||
// 账号已被探测为不支持 Responses(如 DeepSeek/Kimi 等)时,丢出明确提示。
|
||||
// 账号本身可用(网关会走 CC 直转),仅测试入口需要补齐 CC SSE 处理逻辑。
|
||||
// TODO:实现 CC 格式的账号测试路径(需专门的 CC SSE handler)。
|
||||
if !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return s.sendErrorAndEnd(c,
|
||||
"账号已被探测为不支持 OpenAI Responses API(如 DeepSeek/Kimi 等三方兼容上游),"+
|
||||
"账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。",
|
||||
)
|
||||
}
|
||||
apiURL = buildOpenAIResponsesURL(normalizedBaseURL)
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
86
backend/internal/service/admin_balance_history_test.go
Normal file
86
backend/internal/service/admin_balance_history_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
|
||||
older := now.Add(-2 * time.Hour)
|
||||
newer := now.Add(time.Hour)
|
||||
|
||||
usedBy := int64(10)
|
||||
redeemCodes := []RedeemCode{
|
||||
{
|
||||
ID: 1,
|
||||
Type: RedeemTypeBalance,
|
||||
Value: 8,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &now,
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Type: RedeemTypeConcurrency,
|
||||
Value: 1,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &older,
|
||||
CreatedAt: older,
|
||||
},
|
||||
}
|
||||
affiliateCodes := []RedeemCode{
|
||||
{
|
||||
ID: -20,
|
||||
Type: RedeemTypeAffiliateBalance,
|
||||
Value: 3.5,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &newer,
|
||||
CreatedAt: newer,
|
||||
},
|
||||
}
|
||||
|
||||
got := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 2,
|
||||
})
|
||||
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, RedeemTypeAffiliateBalance, got[0].Type)
|
||||
require.Equal(t, RedeemTypeBalance, got[1].Type)
|
||||
}
|
||||
|
||||
func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
|
||||
usedBy := int64(10)
|
||||
at := func(hours int) *time.Time {
|
||||
v := base.Add(time.Duration(hours) * time.Hour)
|
||||
return &v
|
||||
}
|
||||
|
||||
got := mergeBalanceHistoryCodes(
|
||||
[]RedeemCode{
|
||||
{ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)},
|
||||
{ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)},
|
||||
},
|
||||
[]RedeemCode{
|
||||
{ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)},
|
||||
{ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)},
|
||||
},
|
||||
pagination.PaginationParams{Page: 2, PageSize: 2},
|
||||
)
|
||||
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, RedeemTypeConcurrency, got[0].Type)
|
||||
require.Equal(t, int64(-4), got[1].ID)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -973,16 +974,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
if codeType == RedeemTypeAffiliateBalance {
|
||||
codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, total, totalRecharged, nil
|
||||
}
|
||||
|
||||
if codeType == "" {
|
||||
return s.getAllUserBalanceHistory(ctx, userID, params)
|
||||
}
|
||||
|
||||
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
total := result.Total
|
||||
// Aggregate total recharged amount (only once, regardless of type filter)
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, result.Total, totalRecharged, nil
|
||||
return codes, total, totalRecharged, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) {
|
||||
needed := params.Offset() + params.Limit()
|
||||
if needed < params.Limit() {
|
||||
needed = params.Limit()
|
||||
}
|
||||
|
||||
redeemCodes, redeemTotal, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, needed)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
affiliateCodes, affiliateTotal, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, needed)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
codes := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, params)
|
||||
|
||||
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return codes, redeemTotal + affiliateTotal, totalRecharged, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
|
||||
if needed <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
out []RedeemCode
|
||||
total int64
|
||||
)
|
||||
for page := 1; len(out) < needed; page++ {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: 1000}
|
||||
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, "")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result != nil {
|
||||
total = result.Total
|
||||
}
|
||||
out = append(out, codes...)
|
||||
if len(codes) < params.Limit() || int64(len(out)) >= total {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(out) > needed {
|
||||
out = out[:needed]
|
||||
}
|
||||
return out, total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
|
||||
if needed <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
out []RedeemCode
|
||||
total int64
|
||||
)
|
||||
for page := 1; len(out) < needed; page++ {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: 1000}
|
||||
codes, currentTotal, err := s.listAffiliateBalanceHistory(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = currentTotal
|
||||
out = append(out, codes...)
|
||||
if len(codes) < params.Limit() || int64(len(out)) >= total {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(out) > needed {
|
||||
out = out[:needed]
|
||||
}
|
||||
return out, total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
rows, err := s.entClient.QueryContext(ctx, `
|
||||
SELECT id,
|
||||
amount::double precision,
|
||||
created_at
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1
|
||||
AND action = 'transfer'
|
||||
ORDER BY created_at DESC, id DESC
|
||||
OFFSET $2
|
||||
LIMIT $3`, userID, params.Offset(), params.Limit())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
codes := make([]RedeemCode, 0, params.Limit())
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var amount float64
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&id, &amount, &createdAt); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
usedBy := userID
|
||||
usedAt := createdAt
|
||||
codes = append(codes, RedeemCode{
|
||||
ID: -id,
|
||||
Code: fmt.Sprintf("AFF-%d", id),
|
||||
Type: RedeemTypeAffiliateBalance,
|
||||
Value: amount,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &usedBy,
|
||||
UsedAt: &usedAt,
|
||||
CreatedAt: createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
total, err := countAffiliateBalanceHistory(ctx, s.entClient, userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return codes, total, nil
|
||||
}
|
||||
|
||||
func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM user_affiliate_ledger
|
||||
WHERE user_id = $1
|
||||
AND action = 'transfer'`, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var total sql.NullInt64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !total.Valid {
|
||||
return 0, nil
|
||||
}
|
||||
return total.Int64, nil
|
||||
}
|
||||
|
||||
func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode {
|
||||
combined := append(append([]RedeemCode{}, redeemCodes...), affiliateCodes...)
|
||||
sort.SliceStable(combined, func(i, j int) bool {
|
||||
return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j]))
|
||||
})
|
||||
offset := params.Offset()
|
||||
if offset >= len(combined) {
|
||||
return []RedeemCode{}
|
||||
}
|
||||
end := offset + params.Limit()
|
||||
if end > len(combined) {
|
||||
end = len(combined)
|
||||
}
|
||||
return combined[offset:end]
|
||||
}
|
||||
|
||||
func redeemCodeHistoryTime(code RedeemCode) time.Time {
|
||||
if code.UsedAt != nil {
|
||||
return *code.UsedAt
|
||||
}
|
||||
return code.CreatedAt
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
|
||||
|
||||
@@ -98,7 +98,7 @@ type AffiliateRepository interface {
|
||||
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
|
||||
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
|
||||
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
|
||||
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
|
||||
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
|
||||
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
|
||||
@@ -110,6 +110,10 @@ type AffiliateRepository interface {
|
||||
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
|
||||
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
|
||||
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
|
||||
ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error)
|
||||
ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error)
|
||||
ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error)
|
||||
GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error)
|
||||
}
|
||||
|
||||
// AffiliateAdminFilter 列表筛选条件
|
||||
@@ -130,6 +134,76 @@ type AffiliateAdminEntry struct {
|
||||
AffCount int `json:"aff_count"`
|
||||
}
|
||||
|
||||
type AffiliateRecordFilter struct {
|
||||
Search string
|
||||
Page int
|
||||
PageSize int
|
||||
StartAt *time.Time
|
||||
EndAt *time.Time
|
||||
SortBy string
|
||||
SortDesc bool
|
||||
}
|
||||
|
||||
type AffiliateInviteRecord struct {
|
||||
InviterID int64 `json:"inviter_id"`
|
||||
InviterEmail string `json:"inviter_email"`
|
||||
InviterUsername string `json:"inviter_username"`
|
||||
InviteeID int64 `json:"invitee_id"`
|
||||
InviteeEmail string `json:"invitee_email"`
|
||||
InviteeUsername string `json:"invitee_username"`
|
||||
AffCode string `json:"aff_code"`
|
||||
TotalRebate float64 `json:"total_rebate"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateRebateRecord struct {
|
||||
OrderID int64 `json:"order_id"`
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
InviterID int64 `json:"inviter_id"`
|
||||
InviterEmail string `json:"inviter_email"`
|
||||
InviterUsername string `json:"inviter_username"`
|
||||
InviteeID int64 `json:"invitee_id"`
|
||||
InviteeEmail string `json:"invitee_email"`
|
||||
InviteeUsername string `json:"invitee_username"`
|
||||
OrderAmount float64 `json:"order_amount"`
|
||||
PayAmount float64 `json:"pay_amount"`
|
||||
RebateAmount float64 `json:"rebate_amount"`
|
||||
PaymentType string `json:"payment_type"`
|
||||
OrderStatus string `json:"order_status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateTransferRecord struct {
|
||||
LedgerID int64 `json:"ledger_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
UserEmail string `json:"user_email"`
|
||||
Username string `json:"username"`
|
||||
Amount float64 `json:"amount"`
|
||||
BalanceAfter *float64 `json:"balance_after,omitempty"`
|
||||
AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"`
|
||||
FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"`
|
||||
HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"`
|
||||
SnapshotAvailable bool `json:"snapshot_available"`
|
||||
CurrentBalance float64 `json:"-"`
|
||||
RemainingQuota float64 `json:"-"`
|
||||
FrozenQuota float64 `json:"-"`
|
||||
HistoryQuota float64 `json:"-"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type AffiliateUserOverview struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
AffCode string `json:"aff_code"`
|
||||
RebateRatePercent float64 `json:"rebate_rate_percent"`
|
||||
RebateRateCustom bool `json:"-"`
|
||||
InvitedCount int `json:"invited_count"`
|
||||
RebatedInviteeCount int `json:"rebated_invitee_count"`
|
||||
AvailableQuota float64 `json:"available_quota"`
|
||||
HistoryQuota float64 `json:"history_quota"`
|
||||
}
|
||||
|
||||
type AffiliateService struct {
|
||||
repo AffiliateRepository
|
||||
settingService *SettingService
|
||||
@@ -238,6 +312,10 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
|
||||
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -298,7 +376,7 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
|
||||
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
|
||||
}
|
||||
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -488,3 +566,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi
|
||||
}
|
||||
return s.repo.ListUsersWithCustomSettings(ctx, filter)
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter))
|
||||
}
|
||||
|
||||
func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) {
|
||||
if userID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
|
||||
}
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
overview, err := s.repo.GetAffiliateUserOverview(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if overview != nil {
|
||||
if !overview.RebateRateCustom {
|
||||
overview.RebateRatePercent = s.globalRebateRatePercent(ctx)
|
||||
}
|
||||
overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent)
|
||||
}
|
||||
return overview, nil
|
||||
}
|
||||
|
||||
func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter {
|
||||
if filter.Page <= 0 {
|
||||
filter.Page = 1
|
||||
}
|
||||
if filter.PageSize <= 0 {
|
||||
filter.PageSize = 20
|
||||
}
|
||||
if filter.PageSize > 100 {
|
||||
filter.PageSize = 100
|
||||
}
|
||||
filter.Search = strings.TrimSpace(filter.Search)
|
||||
filter.SortBy = strings.TrimSpace(filter.SortBy)
|
||||
return filter
|
||||
}
|
||||
|
||||
@@ -226,6 +226,12 @@ func (s *BillingService) initFallbackPricing() {
|
||||
CacheReadPricePerToken: 7.5e-8,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-7,
|
||||
OutputPricePerToken: 1.25e-6,
|
||||
CacheReadPricePerToken: 2e-8,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.2(本地兜底)
|
||||
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
@@ -295,6 +301,8 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
return s.fallbackPrices["gpt-5.5"]
|
||||
case "gpt-5.4-mini":
|
||||
return s.fallbackPrices["gpt-5.4-mini"]
|
||||
case "gpt-5.4-nano":
|
||||
return s.fallbackPrices["gpt-5.4-nano"]
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.2":
|
||||
|
||||
@@ -51,10 +51,11 @@ const (
|
||||
|
||||
// Redeem type constants
|
||||
const (
|
||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||
RedeemTypeBalance = domain.RedeemTypeBalance
|
||||
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
|
||||
RedeemTypeSubscription = domain.RedeemTypeSubscription
|
||||
RedeemTypeInvitation = domain.RedeemTypeInvitation
|
||||
RedeemTypeAffiliateBalance = "affiliate_balance"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
@@ -336,6 +337,8 @@ const (
|
||||
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
|
||||
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
|
||||
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
||||
// SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection"
|
||||
|
||||
// Balance Low Notification
|
||||
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
||||
|
||||
@@ -1,13 +1,91 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type gatewayTTLSettingRepo struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Get(context.Context, string) (*Setting, error) {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
if r == nil {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
v, ok := r.data[key]
|
||||
if !ok {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Set(_ context.Context, key, value string) error {
|
||||
if r == nil {
|
||||
return errors.New("setting repo is nil")
|
||||
}
|
||||
if r.data == nil {
|
||||
r.data = map[string]string{}
|
||||
}
|
||||
r.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
if r == nil {
|
||||
return result, nil
|
||||
}
|
||||
for _, key := range keys {
|
||||
if v, ok := r.data[key]; ok {
|
||||
result[key] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r == nil {
|
||||
return errors.New("setting repo is nil")
|
||||
}
|
||||
if r.data == nil {
|
||||
r.data = map[string]string{}
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.data[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetAll(context.Context) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
if r == nil {
|
||||
return result, nil
|
||||
}
|
||||
for key, value := range r.data {
|
||||
result[key] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Delete(_ context.Context, key string) error {
|
||||
if r != nil {
|
||||
delete(r.data, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -71,3 +149,60 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
||||
}
|
||||
|
||||
func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) {
|
||||
body := []byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
|
||||
|
||||
result := injectAnthropicCacheControlTTL1h(body)
|
||||
resultStr := string(result)
|
||||
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"cache_control"`, `"system"`, `"messages"`, `"tools"`, `"omega"`)
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "cache_control.ttl").String())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||
require.False(t, gjson.GetBytes(result, "system.1.cache_control").Exists())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||
require.Equal(t, "5m", gjson.GetBytes(result, "messages.0.content.1.cache_control.ttl").String())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "tools.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestGatewayCacheTTLGlobalSetting_TargetResolution(t *testing.T) {
|
||||
repo := &gatewayTTLSettingRepo{data: map[string]string{
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
|
||||
}}
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}
|
||||
|
||||
target, ok := svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, cacheTTLTarget5m, target)
|
||||
|
||||
account.Extra = map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "1h",
|
||||
}
|
||||
target, ok = svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, cacheTTLTarget1h, target)
|
||||
}
|
||||
|
||||
func TestGatewayCacheTTLGlobalSetting_RequestInjectionScope(t *testing.T) {
|
||||
repo := &gatewayTTLSettingRepo{data: map[string]string{
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
|
||||
}}
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
|
||||
require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
|
||||
require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeSetupToken}))
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey}))
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}))
|
||||
|
||||
repo.data[SettingKeyEnableAnthropicCacheTTL1hInjection] = "false"
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
|
||||
}
|
||||
|
||||
@@ -62,6 +62,11 @@ const (
|
||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTTLTarget5m = "5m"
|
||||
cacheTTLTarget1h = "1h"
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
@@ -654,15 +659,31 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
|
||||
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
|
||||
if parsed.MetadataUserID != "" {
|
||||
if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
|
||||
uid := ParseMetadataUserID(parsed.MetadataUserID)
|
||||
if uid != nil && uid.SessionID != "" {
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "metadata_user_id",
|
||||
"session_id", uid.SessionID,
|
||||
"device_id", uid.DeviceID,
|
||||
"is_new_format", uid.IsNewFormat,
|
||||
)
|
||||
return uid.SessionID
|
||||
}
|
||||
slog.Info("sticky.hash_metadata_parse_failed",
|
||||
"metadata_user_id", parsed.MetadataUserID,
|
||||
"parsed_nil", uid == nil,
|
||||
)
|
||||
}
|
||||
|
||||
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
|
||||
cacheableContent := s.extractCacheableContent(parsed)
|
||||
if cacheableContent != "" {
|
||||
return s.hashContent(cacheableContent)
|
||||
hash := s.hashContent(cacheableContent)
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "cacheable_content",
|
||||
"hash", hash,
|
||||
)
|
||||
return hash
|
||||
}
|
||||
|
||||
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
|
||||
@@ -702,7 +723,13 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
}
|
||||
}
|
||||
if combined.Len() > 0 {
|
||||
return s.hashContent(combined.String())
|
||||
hash := s.hashContent(combined.String())
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "message_content_fallback",
|
||||
"hash", hash,
|
||||
"content_len", combined.Len(),
|
||||
)
|
||||
return hash
|
||||
}
|
||||
|
||||
return ""
|
||||
@@ -1406,14 +1433,29 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
var stickyAccountID int64
|
||||
var stickySource string
|
||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
|
||||
stickyAccountID = prefetch
|
||||
stickySource = "prefetch"
|
||||
} else if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
stickySource = "cache"
|
||||
}
|
||||
}
|
||||
|
||||
// [DEBUG-STICKY] 调度器入口日志
|
||||
slog.Info("sticky.scheduler_entry",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"session_hash", shortSessionHash(sessionHash),
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"sticky_source", stickySource,
|
||||
"model", requestedModel,
|
||||
"load_batch", cfg.LoadBatchEnabled,
|
||||
"has_concurrency_svc", s.concurrencyService != nil,
|
||||
"excluded_count", len(excludedIDs),
|
||||
)
|
||||
|
||||
if s.debugModelRoutingEnabled() && requestedModel != "" {
|
||||
groupPlatform := ""
|
||||
if group != nil {
|
||||
@@ -1589,6 +1631,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if len(routingCandidates) > 0 {
|
||||
// 1.5. 在路由账号范围内检查粘性会话
|
||||
if sessionHash != "" && stickyAccountID > 0 {
|
||||
slog.Debug("sticky.layer1_5_checking",
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"in_routing_list", containsInt64(routingAccountIDs, stickyAccountID),
|
||||
"is_excluded", isExcluded(stickyAccountID),
|
||||
"in_account_map", func() bool { _, ok := accountByID[stickyAccountID]; return ok }(),
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
@@ -1612,6 +1661,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_hit",
|
||||
"account_id", stickyAccountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "slot_acquired",
|
||||
)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
@@ -1762,27 +1816,65 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
clearSticky := shouldClearStickySession(account, requestedModel)
|
||||
if clearSticky {
|
||||
slog.Debug("sticky.layer1_5_no_routing_clear",
|
||||
"account_id", accountID,
|
||||
"reason", "should_clear_sticky_session",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(account) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
|
||||
// 注意:不再检查 isAccountInGroup,因为 accountByID 已经从按分组过滤的
|
||||
// accounts 列表构建,账号一定在分组内。而 scheduler snapshot 缓存
|
||||
// 反序列化后 AccountGroups 字段为空,导致 isAccountInGroup 永远返回 false。
|
||||
platformOK := s.isAccountAllowedForPlatform(account, platform, useMixed)
|
||||
modelSupported := requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)
|
||||
modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, requestedModel)
|
||||
quotaOK := s.isAccountSchedulableForQuota(account)
|
||||
windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true)
|
||||
rpmOK := s.isAccountSchedulableForRPM(ctx, account, true)
|
||||
schedulable := s.isAccountSchedulableForSelection(account)
|
||||
|
||||
slog.Debug("sticky.layer1_5_no_routing_checks",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"clear_sticky", clearSticky,
|
||||
"schedulable", schedulable,
|
||||
"platform_ok", platformOK,
|
||||
"model_supported", modelSupported,
|
||||
"model_schedulable", modelSchedulable,
|
||||
"quota_ok", quotaOK,
|
||||
"window_cost_ok", windowCostOK,
|
||||
"rpm_ok", rpmOK,
|
||||
)
|
||||
|
||||
if !clearSticky && platformOK && modelSupported && modelSchedulable && quotaOK && windowCostOK && rpmOK && schedulable {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "session_limit",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_hit",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "slot_acquired",
|
||||
)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
}
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_slot_busy",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
@@ -1791,6 +1883,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_hit",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "wait_plan",
|
||||
)
|
||||
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
@@ -1799,12 +1896,42 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if !clearSticky {
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "gate_check_failed",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "account_not_in_map",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if len(routingAccountIDs) == 0 && sessionHash != "" {
|
||||
slog.Debug("sticky.layer1_5_no_routing_skip",
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"is_excluded", func() bool { return stickyAccountID > 0 && isExcluded(stickyAccountID) }(),
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"reason", func() string {
|
||||
if stickyAccountID == 0 {
|
||||
return "no_sticky_binding"
|
||||
}
|
||||
return "sticky_account_excluded"
|
||||
}(),
|
||||
)
|
||||
}
|
||||
|
||||
// ============ Layer 2: 负载感知选择 ============
|
||||
slog.Debug("sticky.layer2_fallback",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"reason", "sticky_not_used_falling_back_to_load_balance",
|
||||
"total_accounts", len(accounts),
|
||||
)
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -4104,6 +4231,87 @@ func enforceCacheControlLimit(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// injectAnthropicCacheControlTTL1h 将已有 ephemeral cache_control 块的 ttl 强制写为 1h。
|
||||
// 仅修改已经存在的 cache_control,不新增缓存断点。
|
||||
func injectAnthropicCacheControlTTL1h(body []byte) []byte {
|
||||
return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h)
|
||||
}
|
||||
|
||||
func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte {
|
||||
if len(body) == 0 || ttl == "" {
|
||||
return body
|
||||
}
|
||||
out := body
|
||||
var paths []string
|
||||
addPath := func(path string, value gjson.Result) {
|
||||
cc := value.Get("cache_control")
|
||||
if !cc.Exists() || cc.Get("type").String() != "ephemeral" {
|
||||
return
|
||||
}
|
||||
if cc.Get("ttl").String() == ttl {
|
||||
return
|
||||
}
|
||||
paths = append(paths, path+".cache_control.ttl")
|
||||
}
|
||||
|
||||
if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl {
|
||||
paths = append(paths, "cache_control.ttl")
|
||||
}
|
||||
|
||||
system := gjson.GetBytes(body, "system")
|
||||
if system.IsArray() {
|
||||
idx := -1
|
||||
system.ForEach(func(_, block gjson.Result) bool {
|
||||
idx++
|
||||
addPath(fmt.Sprintf("system.%d", idx), block)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if messages.IsArray() {
|
||||
msgIdx := -1
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
msgIdx++
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
return true
|
||||
}
|
||||
contentIdx := -1
|
||||
content.ForEach(func(_, block gjson.Result) bool {
|
||||
contentIdx++
|
||||
addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.IsArray() {
|
||||
idx := -1
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
idx++
|
||||
addPath(fmt.Sprintf("tools.%d", idx), tool)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
if next, err := sjson.SetBytes(out, path, ttl); err == nil {
|
||||
out = next
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool {
|
||||
if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil {
|
||||
return false
|
||||
}
|
||||
return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx)
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -4263,6 +4471,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
|
||||
if s.shouldInjectAnthropicCacheTTL1h(ctx, account) {
|
||||
body = injectAnthropicCacheControlTTL1h(body)
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -7103,9 +7315,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
if eventType == "message_start" {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
@@ -7512,6 +7724,19 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) {
|
||||
if account == nil {
|
||||
return "", false
|
||||
}
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
return account.GetCacheTTLOverrideTarget(), true
|
||||
}
|
||||
if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) {
|
||||
return cacheTTLTarget5m, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
@@ -7548,9 +7773,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
||||
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
|
||||
@@ -7949,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
|
||||
}
|
||||
|
||||
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
if !stream {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithoutCancel(ctx), func() {}
|
||||
}
|
||||
|
||||
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
@@ -8118,10 +8350,11 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
applyCacheTTLOverride(&result.Usage, overrideTarget)
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type upstreamContextTestKey string
|
||||
|
||||
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
|
||||
require.Equal(t, 3, result.usage.InputTokens)
|
||||
require.Equal(t, 7, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
|
||||
parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value"))
|
||||
upstreamCtx, release := detachUpstreamContext(parent)
|
||||
defer release()
|
||||
|
||||
cancel()
|
||||
|
||||
require.NoError(t, upstreamCtx.Err())
|
||||
require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key")))
|
||||
}
|
||||
|
||||
149
backend/internal/service/openai_apikey_responses_probe.go
Normal file
149
backend/internal/service/openai_apikey_responses_probe.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
)
|
||||
|
||||
// openaiResponsesProbeTimeout 是探测请求的超时时长。
|
||||
// 探测必须快速失败——超时不应阻塞账号创建/更新流程。
|
||||
const openaiResponsesProbeTimeout = 8 * time.Second
|
||||
|
||||
// openaiResponsesProbePayload 是探测使用的最小 Responses 请求体。
|
||||
// 仅作能力探测,不期望响应内容质量;Stream=false 减少 SSE 解析开销。
|
||||
//
|
||||
// 注意:探测的目标是区分"端点存在"与"端点不存在"——只要上游返回非 404 的
|
||||
// 4xx/5xx(如 400 invalid_request_error / 401 unauthorized / 422 等),
|
||||
// 都视为"端点存在 → 支持 Responses"。仅 404 / 405 视为"端点不存在"。
|
||||
func openaiResponsesProbePayload(modelID string) []byte {
|
||||
if strings.TrimSpace(modelID) == "" {
|
||||
modelID = openai.DefaultTestModel
|
||||
}
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": modelID,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{"type": "input_text", "text": "hi"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"instructions": openai.DefaultInstructions,
|
||||
"stream": false,
|
||||
})
|
||||
return body
|
||||
}
|
||||
|
||||
// ProbeOpenAIAPIKeyResponsesSupport 探测 OpenAI APIKey 账号上游是否支持
|
||||
// /v1/responses 端点,并将结果持久化到 accounts.extra.openai_responses_supported。
|
||||
//
|
||||
// 调用时机:账号创建/更新后,且仅当 platform=openai && type=apikey 时。
|
||||
//
|
||||
// 探测策略(参见包文档 internal/pkg/openai_compat):
|
||||
// - 上游 404 / 405 → 不支持,写 false
|
||||
// - 上游 2xx / 其他 4xx(401/422/400 等)/ 5xx → 支持,写 true
|
||||
// - 网络层失败(连接错误、超时)→ 不写标记,保持 unknown
|
||||
// (后续请求仍按"现状即证据"默认走 Responses)
|
||||
//
|
||||
// 该方法是幂等的:重复调用会以最新探测结果覆盖标记。
|
||||
//
|
||||
// 关于失败处理:探测本身的失败不应阻塞账号创建——账号能创建/更新成功就够了,
|
||||
// 探测结果只影响后续路由优化。所有错误都仅记录日志,不向调用方传播。
|
||||
// ProbeOpenAIAPIKeyResponsesSupport checks whether the upstream behind an
// OpenAI API-key account exposes the /v1/responses endpoint and persists the
// verdict into the account's extra data under
// openai_compat.ExtraKeyResponsesSupported.
//
// Any failure before a definitive HTTP status is obtained (account load,
// base-URL validation, request build, network error) leaves the marker
// untouched, keeping the account in the "unknown" state for a later retry or
// the gateway's fallback handling.
func (s *AccountTestService) ProbeOpenAIAPIKeyResponsesSupport(ctx context.Context, accountID int64) {
	account, err := s.accountRepo.GetByID(ctx, accountID)
	if err != nil {
		logger.LegacyPrintf("service.openai_probe", "probe_load_account_failed: account_id=%d err=%v", accountID, err)
		return
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeAPIKey {
		// Only OpenAI API-key accounts need probing; other account types have
		// no capability difference to detect.
		return
	}

	apiKey := account.GetOpenAIApiKey()
	if apiKey == "" {
		logger.LegacyPrintf("service.openai_probe", "probe_skip_no_apikey: account_id=%d", accountID)
		return
	}
	baseURL := account.GetOpenAIBaseURL()
	if baseURL == "" {
		baseURL = "https://api.openai.com"
	}
	normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
	if err != nil {
		logger.LegacyPrintf("service.openai_probe", "probe_invalid_baseurl: account_id=%d base_url=%q err=%v", accountID, baseURL, err)
		return
	}

	probeURL := buildOpenAIResponsesURL(normalizedBaseURL)

	// The probe request itself runs under a dedicated timeout; the caller's
	// ctx is kept for the persistence step below.
	probeCtx, cancel := context.WithTimeout(ctx, openaiResponsesProbeTimeout)
	defer cancel()

	// Minimal probe payload: only the resulting HTTP status code is inspected.
	req, err := http.NewRequestWithContext(probeCtx, http.MethodPost, probeURL, bytes.NewReader(openaiResponsesProbePayload("")))
	if err != nil {
		logger.LegacyPrintf("service.openai_probe", "probe_build_request_failed: account_id=%d err=%v", accountID, err)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Accept", "application/json")

	// Honor the account's proxy configuration, if both ID and proxy are set.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
	if err != nil {
		// Network-layer failure: do not write the marker; stay "unknown" so a
		// later retry or the gateway fallback can decide.
		logger.LegacyPrintf("service.openai_probe", "probe_request_failed: account_id=%d url=%s err=%v", accountID, probeURL, err)
		return
	}
	defer func() {
		// Drain (capped at 1 MiB) and close so the transport can reuse the
		// underlying connection.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
		_ = resp.Body.Close()
	}()

	supported := isResponsesEndpointSupportedByStatus(resp.StatusCode)

	// Persist using the caller's ctx, not the probe-scoped timeout ctx.
	if err := s.accountRepo.UpdateExtra(ctx, accountID, map[string]any{
		openai_compat.ExtraKeyResponsesSupported: supported,
	}); err != nil {
		logger.LegacyPrintf("service.openai_probe", "probe_persist_failed: account_id=%d supported=%v err=%v", accountID, supported, err)
		return
	}

	logger.LegacyPrintf("service.openai_probe",
		"probe_done: account_id=%d base_url=%s status=%d supported=%v",
		accountID, normalizedBaseURL, resp.StatusCode, supported,
	)
}
|
||||
|
||||
// isResponsesEndpointSupportedByStatus decides, from the probe response's
// HTTP status code, whether the upstream exposes the /v1/responses endpoint.
//
// Key observation: third-party OpenAI-compatible upstreams (DeepSeek/Kimi,
// etc.) answer unknown endpoints uniformly with 404 or 405, while OpenAI
// proper (or any upstream with a Responses implementation) rejects the
// deliberately minimal probe body with a business error such as 400/422 —
// but the endpoint itself exists.
//
// Therefore only 404 and 405 mean "endpoint missing". Every other status,
// including 5xx, means "endpoint present" — a transient upstream fault must
// not be misread as lack of support.
func isResponsesEndpointSupportedByStatus(status int) bool {
	return status != http.StatusNotFound && status != http.StatusMethodNotAllowed
}
|
||||
@@ -38,6 +38,29 @@ var codexModelMap = map[string]string{
|
||||
"gpt-5.2-medium": "gpt-5.2",
|
||||
"gpt-5.2-high": "gpt-5.2",
|
||||
"gpt-5.2-xhigh": "gpt-5.2",
|
||||
"gpt-5": "gpt-5.4",
|
||||
"gpt-5-mini": "gpt-5.4",
|
||||
"gpt-5-nano": "gpt-5.4",
|
||||
"gpt-5.1": "gpt-5.4",
|
||||
"gpt-5.1-codex": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex-max": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex-mini": "gpt-5.3-codex",
|
||||
"gpt-5.2-codex": "gpt-5.2",
|
||||
"codex-mini-latest": "gpt-5.3-codex",
|
||||
"gpt-5-codex": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
// codexVersionModelPrefixes maps versioned model-name prefixes to their
// canonical codex target model. The matching code iterates this slice in
// order and stops at the first hit (exact prefix, or prefix + "-" + known
// suffix), so longer, more specific prefixes (e.g. "gpt-5.3-codex-spark")
// must come before their shorter ancestors (e.g. "gpt-5.3-codex").
var codexVersionModelPrefixes = []struct {
	prefix string
	target string
}{
	{prefix: "gpt-5.3-codex-spark", target: "gpt-5.3-codex-spark"},
	{prefix: "gpt-5.3-codex", target: "gpt-5.3-codex"},
	{prefix: "gpt-5.4-mini", target: "gpt-5.4-mini"},
	{prefix: "gpt-5.4-nano", target: "gpt-5.4-nano"},
	{prefix: "gpt-5.5", target: "gpt-5.5"},
	{prefix: "gpt-5.4", target: "gpt-5.4"},
	{prefix: "gpt-5.2", target: "gpt-5.2"},
}
|
||||
|
||||
type codexTransformResult struct {
|
||||
@@ -447,8 +470,19 @@ func normalizeCodexModel(model string) string {
|
||||
if model == "" {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
if mapped, ok := normalizeKnownCodexModel(model); ok {
|
||||
return mapped
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
func normalizeKnownCodexModel(model string) (string, bool) {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
if isOpenAIImageGenerationModel(model) {
|
||||
return model
|
||||
return model, true
|
||||
}
|
||||
|
||||
modelID := model
|
||||
@@ -457,41 +491,58 @@ func normalizeCodexModel(model string) string {
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
|
||||
if mapped := getNormalizedCodexModel(modelID); mapped != "" {
|
||||
return mapped
|
||||
key := codexModelLookupKey(modelID)
|
||||
if key == "" {
|
||||
return "", false
|
||||
}
|
||||
if mapped := getNormalizedCodexModel(key); mapped != "" {
|
||||
return mapped, true
|
||||
}
|
||||
for _, item := range codexVersionModelPrefixes {
|
||||
if key == item.prefix {
|
||||
return item.target, true
|
||||
}
|
||||
suffix, ok := strings.CutPrefix(key, item.prefix+"-")
|
||||
if ok && isKnownCodexModelSuffix(suffix) {
|
||||
return item.target, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
normalized := strings.ToLower(modelID)
|
||||
// codexModelLookupKey canonicalizes a requested model ID for table lookup:
// surrounding whitespace is trimmed, only the segment after the last "/" is
// kept (dropping provider prefixes like "openai/"), internal whitespace runs
// collapse to single "-" separators, and the result is lowercased. A blank
// input yields "".
func codexModelLookupKey(modelID string) string {
	trimmed := strings.TrimSpace(modelID)
	if trimmed == "" {
		return ""
	}
	if idx := strings.LastIndex(trimmed, "/"); idx >= 0 {
		trimmed = trimmed[idx+1:]
	}
	return strings.ToLower(strings.Join(strings.Fields(trimmed), "-"))
}
|
||||
|
||||
if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") {
|
||||
return "gpt-5.5"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
|
||||
return "gpt-5.4-mini"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||
return "gpt-5.2"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") {
|
||||
return "gpt-5.3-codex-spark"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "codex") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
|
||||
return "gpt-5.4"
|
||||
func isKnownCodexModelSuffix(suffix string) bool {
|
||||
switch suffix {
|
||||
case "none", "minimal", "low", "medium", "high", "xhigh":
|
||||
return true
|
||||
}
|
||||
return isCodexDateSuffix(suffix)
|
||||
}
|
||||
|
||||
return "gpt-5.4"
|
||||
// isCodexDateSuffix reports whether suffix has the shape NNNN-NN-NN (e.g.
// "2025-01-31"): exactly three "-"-separated groups of 4, 2 and 2 ASCII
// digits. It validates shape only; no calendar validity check is performed.
func isCodexDateSuffix(suffix string) bool {
	segments := strings.Split(suffix, "-")
	if len(segments) != 3 {
		return false
	}
	wantLens := [3]int{4, 2, 2}
	for i, seg := range segments {
		if len(seg) != wantLens[i] {
			return false
		}
		for _, ch := range seg {
			if ch < '0' || ch > '9' {
				return false
			}
		}
	}
	return true
}
|
||||
|
||||
func isCodexSparkModel(model string) bool {
|
||||
@@ -789,18 +840,13 @@ func SupportsVerbosity(model string) bool {
|
||||
}
|
||||
|
||||
func getNormalizedCodexModel(modelID string) string {
|
||||
if modelID == "" {
|
||||
key := codexModelLookupKey(modelID)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
if mapped, ok := codexModelMap[modelID]; ok {
|
||||
if mapped, ok := codexModelMap[key]; ok {
|
||||
return mapped
|
||||
}
|
||||
lower := strings.ToLower(modelID)
|
||||
for key, value := range codexModelMap {
|
||||
if strings.ToLower(key) == lower {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
@@ -3,13 +3,16 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
@@ -18,6 +21,51 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type openAICompatFailingWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *openAICompatFailingWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
// openAICompatBlockingReadCloser serves a fixed payload and then blocks every
// subsequent Read until Close is called, emulating an upstream that keeps the
// connection open after its final event.
type openAICompatBlockingReadCloser struct {
	payload  []byte
	pos      int
	done     chan struct{}
	shutOnce sync.Once
}

// newOpenAICompatBlockingReadCloser wraps data in a reader that blocks after
// the payload is exhausted.
func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser {
	return &openAICompatBlockingReadCloser{payload: data, done: make(chan struct{})}
}

// Read drains the payload first; once exhausted it parks until Close fires,
// then reports io.EOF.
func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) {
	if remaining := r.payload[r.pos:]; len(remaining) > 0 {
		n := copy(p, remaining)
		r.pos += n
		return n, nil
	}
	<-r.done
	return 0, io.EOF
}

// Close unblocks pending and future Reads; safe to call multiple times.
func (r *openAICompatBlockingReadCloser) Close() error {
	r.shutOnce.Do(func() { close(r.done) })
	return nil
}
|
||||
|
||||
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -228,3 +276,242 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
|
||||
}
|
||||
|
||||
// TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage verifies that
// when the client-side writer fails on the very first write (failAfter: 0),
// ForwardAsAnthropic still drains the upstream SSE stream and returns the
// usage figures from the terminal response.completed event, so billing data
// is not lost on client disconnect.
func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	// failAfter: 0 — every write to the client fails immediately.
	c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
	body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	// Upstream SSE: created → text delta → completed (carrying usage) → [DONE].
	upstreamBody := strings.Join([]string{
		`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
		"",
		`data: {"type":"response.output_text.delta","delta":"ok"}`,
		"",
		`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}},
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}

	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
	require.NoError(t, err)
	require.NotNil(t, result)
	// Usage must match the upstream terminal event despite the dead client.
	require.Equal(t, 9, result.Usage.InputTokens)
	require.Equal(t, 4, result.Usage.OutputTokens)
	require.Equal(t, 3, result.Usage.CacheReadInputTokens)
}
|
||||
|
||||
// TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns verifies
// that the streaming path returns as soon as the terminal usage event has
// been read, even when the upstream never closes its connection (the blocking
// reader only EOFs after Close).
func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
	body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
	// Blocks after the terminal event instead of returning EOF.
	upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
	defer func() {
		require.NoError(t, upstreamStream.Close())
	}()
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}},
		Body:       upstreamStream,
	}}

	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	// Run the forward in a goroutine so a hang shows up as a test timeout
	// instead of blocking the whole test binary.
	type forwardResult struct {
		result *OpenAIForwardResult
		err    error
	}
	resultCh := make(chan forwardResult, 1)
	go func() {
		result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
		resultCh <- forwardResult{result: result, err: err}
	}()

	select {
	case got := <-resultCh:
		require.NoError(t, got.err)
		require.NotNil(t, got.result)
		require.Equal(t, 15, got.result.Usage.InputTokens)
		require.Equal(t, 6, got.result.Usage.OutputTokens)
		require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
	case <-time.After(time.Second):
		require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open")
	}
}
|
||||
|
||||
// TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns verifies
// the buffered (stream:false) path: ForwardAsAnthropic returns with usage and
// the converted Anthropic body once the terminal event is seen, even though
// the upstream never closes its connection.
func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
	// Blocks after the terminal event instead of returning EOF.
	upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
	defer func() {
		require.NoError(t, upstreamStream.Close())
	}()
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}},
		Body:       upstreamStream,
	}}

	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	// Run the forward in a goroutine so a hang surfaces via the select below.
	type forwardResult struct {
		result *OpenAIForwardResult
		err    error
	}
	resultCh := make(chan forwardResult, 1)
	go func() {
		result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
		resultCh <- forwardResult{result: result, err: err}
	}()

	select {
	case got := <-resultCh:
		require.NoError(t, got.err)
		require.NotNil(t, got.result)
		require.Equal(t, 15, got.result.Usage.InputTokens)
		require.Equal(t, 6, got.result.Usage.OutputTokens)
		require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
		// The buffered path must have written the converted Anthropic body.
		require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`)
	case <-time.After(time.Second):
		require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open")
	}
}
|
||||
|
||||
// TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError verifies
// that a stream carrying only "data: [DONE]" — with no terminal response
// event — yields a "missing terminal event" error and a non-nil result whose
// usage counters are zero.
func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	// Only the DONE sentinel — no response.completed/failed/incomplete event.
	upstreamBody := "data: [DONE]\n\n"
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}},
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}

	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
	require.Error(t, err)
	require.Contains(t, err.Error(), "missing terminal event")
	require.NotNil(t, result)
	require.Zero(t, result.Usage.InputTokens)
	require.Zero(t, result.Usage.OutputTokens)
}
|
||||
|
||||
// TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel verifies that the
// upstream request does not inherit the client request context: with the
// client context canceled before forwarding, the forward still succeeds and
// the recorded upstream request's context reports no error.
func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	reqCtx, cancel := context.WithCancel(context.Background())
	body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx)
	c.Request.Header.Set("Content-Type", "application/json")
	// Cancel up front: the client context is dead before the forward starts.
	cancel()

	upstreamBody := strings.Join([]string{
		`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}},
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}

	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, upstream.lastReq)
	// The upstream request must carry a live context despite the canceled
	// client context.
	require.NoError(t, upstream.lastReq.Context().Err())
}
|
||||
|
||||
@@ -972,6 +972,62 @@ func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing
|
||||
"turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
|
||||
}
|
||||
|
||||
// TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns locks in how the
// WS passthrough usage metadata tracks reasoning effort across frames:
// session.update and other non-response.create frames leave the current
// turn's value alone, a flat reasoning_effort on a response.create frame is
// normalized and recorded, and a fresh response.create without any effort
// clears the previously stored value.
func TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// First turn: nested reasoning.effort "medium" seeds the metadata.
	firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","reasoning":{"effort":"medium"},"service_tier":"priority"}`)
	meta := newOpenAIWSPassthroughUsageMeta("", firstFrame)
	capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
	firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, capturedSessionModel, firstFrame)
	require.NoError(t, firstErr)
	require.Nil(t, firstBlocked)
	meta.initFromFirstFrame(firstOut)
	require.NotNil(t, meta.reasoningEffort.Load())
	require.Equal(t, "medium", *meta.reasoningEffort.Load())

	// process mirrors the production WS loop: refresh the session fallback
	// model, apply the fast policy, then fold accepted response.create frames
	// into the usage metadata.
	process := func(payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
		if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
			capturedSessionModel = updated
		}
		meta.updateSessionRequestModel(payload)
		requestModelForThisFrame := meta.requestModelForFrame(payload)
		model := openAIWSPassthroughPolicyModelForFrame(account, payload)
		if model == "" {
			model = capturedSessionModel
		}
		out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
		if policyErr == nil && blocked == nil &&
			strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
			meta.updateFromResponseCreate(out, requestModelForThisFrame)
		}
		return out, blocked, policyErr
	}

	// session.update only refreshes the fallback model for later frames.
	_, blockedSession, errSession := process([]byte(`{"type":"session.update","session":{"model":"gpt-5-high"}}`))
	require.NoError(t, errSession)
	require.Nil(t, blockedSession)
	require.NotNil(t, meta.reasoningEffort.Load())
	require.Equal(t, "medium", *meta.reasoningEffort.Load(), "session.update 只刷新后续 fallback model,不覆盖当前 turn metadata")

	// Non-response.create frames must not pollute the current turn metadata.
	_, blockedCancel, errCancel := process([]byte(`{"type":"response.cancel","reasoning_effort":"x-high"}`))
	require.NoError(t, errCancel)
	require.Nil(t, blockedCancel)
	require.NotNil(t, meta.reasoningEffort.Load())
	require.Equal(t, "medium", *meta.reasoningEffort.Load(), "非 response.create 帧不能污染当前 turn metadata")

	// Flat reasoning_effort "x-high" is normalized to "xhigh" and recorded.
	_, blockedFlat, errFlat := process([]byte(`{"type":"response.create","reasoning_effort":"x-high"}`))
	require.NoError(t, errFlat)
	require.Nil(t, blockedFlat)
	require.NotNil(t, meta.reasoningEffort.Load())
	require.Equal(t, "xhigh", *meta.reasoningEffort.Load(), "flat reasoning_effort 必须进入 passthrough usage metadata")

	// A new response.create with no effort and no derivable suffix clears it.
	_, blockedClear, errClear := process([]byte(`{"type":"response.create","model":"gpt-4o"}`))
	require.NoError(t, errClear)
	require.Nil(t, blockedClear)
	require.Nil(t, meta.reasoningEffort.Load(), "新的 response.create 无 effort 且无可推导后缀时必须清空旧值")
}
|
||||
|
||||
// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
|
||||
// "block keeps previous" semantic: when policy returns block on a
|
||||
// response.create frame, that frame is never sent upstream, so billing tier
|
||||
|
||||
@@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
|
||||
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) {
|
||||
counter := &openAI403CounterResetStub{}
|
||||
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
|
||||
rateLimitSvc.SetOpenAI403CounterCache(counter)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
rateLimitService: rateLimitSvc,
|
||||
}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
svc.rateLimitService = rateLimitSvc
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{},
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_zero_usage_reset_403",
|
||||
Model: "gpt-5.1",
|
||||
},
|
||||
APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}},
|
||||
User: &User{ID: 2001},
|
||||
Account: &Account{ID: 777, Platform: PlatformOpenAI},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{777}, counter.resetCalls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
}
|
||||
|
||||
@@ -10,10 +10,12 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -39,9 +41,18 @@ var cursorResponsesUnsupportedFields = []string{
|
||||
|
||||
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
|
||||
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
|
||||
// the response back to Chat Completions format. All account types (OAuth and API
|
||||
// Key) go through the Responses API conversion path since the upstream only
|
||||
// exposes the /v1/responses endpoint.
|
||||
// the response back to Chat Completions format.
|
||||
//
|
||||
// 历史背景:该函数原本对所有 OpenAI 账号无差别走 CC→Responses 转换 + /v1/responses
|
||||
// 端点——这在 OAuth(ChatGPT 内部 API 仅支持 Responses)和官方 APIKey 账号上是
|
||||
// 正确的,但 sub2api 接入 DeepSeek/Kimi/GLM 等第三方 OpenAI 兼容上游后假设破裂:
|
||||
// 这些上游普遍只支持 /v1/chat/completions,无 /v1/responses 端点。
|
||||
//
|
||||
// 当前路由策略(基于账号探测标记,详见 openai_compat.ShouldUseResponsesAPI):
|
||||
// - APIKey 账号 + 探测确认不支持 Responses → 走 forwardAsRawChatCompletions
|
||||
// 直转上游 /v1/chat/completions,不做协议转换
|
||||
// - 其他所有情况(OAuth、APIKey 探测确认支持、未探测)→ 走原有 CC→Responses
|
||||
// 转换路径(保留旧行为,存量未探测账号零兼容破坏)
|
||||
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
@@ -50,6 +61,12 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
promptCacheKey string,
|
||||
defaultMappedModel string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
// 入口分流:APIKey 账号 + 已探测且确认上游不支持 Responses,走 CC 直转。
|
||||
// 标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。
|
||||
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse Chat Completions request
|
||||
@@ -189,7 +206,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
@@ -348,59 +367,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai chat_completions buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate delta content for fallback when terminal output is empty.
|
||||
acc.ProcessEvent(&event)
|
||||
|
||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai chat_completions buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
@@ -459,6 +428,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
clientDisconnected := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -467,6 +437,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
@@ -496,54 +480,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil && event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||
}
|
||||
|
||||
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
||||
for _, chunk := range chunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
logger.L().Info("openai chat_completions stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return true
|
||||
if !clientDisconnected {
|
||||
for _, chunk := range chunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(chunks) > 0 {
|
||||
if len(chunks) > 0 && !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
return isTerminalEvent
|
||||
}
|
||||
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
|
||||
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
|
||||
for _, chunk := range finalChunks {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Send [DONE] sentinel
|
||||
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
|
||||
c.Writer.Flush()
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
@@ -555,6 +551,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
)
|
||||
}
|
||||
}
|
||||
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
// Determine keepalive interval
|
||||
keepaliveInterval := time.Duration(0)
|
||||
@@ -563,18 +562,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
|
||||
// No keepalive: fast synchronous path
|
||||
if keepaliveInterval <= 0 {
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if strings.TrimSpace(payload) == "[DONE]" {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
if err := scanner.Err(); err != nil {
|
||||
handleScanErr(err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
return missingTerminalErr()
|
||||
}
|
||||
|
||||
// With keepalive: goroutine + channel + select
|
||||
@@ -584,6 +590,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
@@ -595,6 +603,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
@@ -605,30 +614,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return finalizeStream()
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if strings.TrimSpace(payload) == "[DONE]" {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.L().Warn("openai chat_completions stream: data interval timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("model", originalModel),
|
||||
zap.Duration("interval", streamInterval),
|
||||
)
|
||||
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
@@ -637,7 +675,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
||||
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
clientDisconnected = true
|
||||
continue
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
437
backend/internal/service/openai_gateway_chat_completions_raw.go
Normal file
437
backend/internal/service/openai_gateway_chat_completions_raw.go
Normal file
@@ -0,0 +1,437 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// openaiCCRawAllowedHeaders 是 CC 直转路径专用的客户端 header 透传白名单。
|
||||
//
|
||||
// **关键**:不能复用 openaiAllowedHeaders——后者含 Codex 客户端专属 header
|
||||
// (originator / session_id / x-codex-turn-state / x-codex-turn-metadata / conversation_id),
|
||||
// 这些在 ChatGPT OAuth 上游是必需的,但透传给 DeepSeek/Kimi/GLM 等第三方
|
||||
// OpenAI 兼容上游会造成:
|
||||
// - 完全忽略(多数友好厂商)——隐性污染上游统计
|
||||
// - 400 "unknown parameter"(严格上游)——可见错误
|
||||
//
|
||||
// 这里仅放行通用 HTTP header;content-type / authorization / accept 由上下文
|
||||
// 显式设置,不依赖透传。
|
||||
//
|
||||
// 参见决策记录:
|
||||
// pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains
|
||||
var openaiCCRawAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
"user-agent": true,
|
||||
}
|
||||
|
||||
// forwardAsRawChatCompletions 直转客户端的 Chat Completions 请求到上游
|
||||
// `{base_url}/v1/chat/completions`,**不**做 CC↔Responses 协议转换。
|
||||
//
|
||||
// 适用场景:account.platform=openai && account.type=apikey && 上游已被探测确认
|
||||
// 不支持 /v1/responses 端点(如 DeepSeek/Kimi/GLM/Qwen 等第三方 OpenAI 兼容上游)。
|
||||
//
|
||||
// 与 ForwardAsChatCompletions 的关键差异:
|
||||
//
|
||||
// - 不调用 apicompat.ChatCompletionsToResponses,body 仅做模型 ID 改写
|
||||
// - 上游 URL 拼到 /v1/chat/completions 而非 /v1/responses
|
||||
// - 流式响应 SSE 直接透传给客户端(上游 chunk 已是 CC 格式)
|
||||
// - 非流式响应 JSON 直接透传,仅按需提取 usage
|
||||
// - 不应用 codex OAuth transform(APIKey 路径无 OAuth)
|
||||
// - 不注入 prompt_cache_key(OAuth 专属机制)
|
||||
//
|
||||
// 调用入口:openai_gateway_chat_completions.go::ForwardAsChatCompletions
|
||||
// 在函数顶部按 openai_compat.ShouldUseResponsesAPI 分流。
|
||||
func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
defaultMappedModel string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse minimal fields needed for routing/billing
|
||||
originalModel := gjson.GetBytes(body, "model").String()
|
||||
if originalModel == "" {
|
||||
writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return nil, fmt.Errorf("missing model in request")
|
||||
}
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
// 1b. Extract reasoning effort and service tier from the raw body before any transformation.
|
||||
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
|
||||
serviceTier := extractOpenAIServiceTierFromBody(body)
|
||||
|
||||
// 2. Resolve model mapping (same as ForwardAsChatCompletions)
|
||||
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||||
|
||||
// 3. Rewrite model in body (no protocol conversion)
|
||||
upstreamBody := body
|
||||
if upstreamModel != originalModel {
|
||||
upstreamBody = ReplaceModelInBody(body, upstreamModel)
|
||||
}
|
||||
|
||||
// 4. Apply OpenAI fast policy on the CC body
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, upstreamBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
upstreamBody = updatedBody
|
||||
if clientStream {
|
||||
var usageErr error
|
||||
upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody)
|
||||
if usageErr != nil {
|
||||
return nil, fmt.Errorf("enable stream usage: %w", usageErr)
|
||||
}
|
||||
}
|
||||
|
||||
logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("billing_model", billingModel),
|
||||
zap.String("upstream_model", upstreamModel),
|
||||
zap.Bool("stream", clientStream),
|
||||
)
|
||||
|
||||
// 5. Build upstream request
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("account %d missing api_key", account.ID)
|
||||
}
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if clientStream {
|
||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
upstreamReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiCCRawAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
upstreamReq.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
upstreamReq.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// 6. Send request
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 7. Handle error response with failover
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleChatCompletionsErrorResponse(resp, c, account)
|
||||
}
|
||||
|
||||
// 8. Forward response
|
||||
if clientStream {
|
||||
return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||
}
|
||||
return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||
}
|
||||
|
||||
// streamRawChatCompletions 透传上游 CC SSE 流到客户端,并提取 usage(包括
|
||||
// 末尾 [DONE] 之前的 chunk 中的 usage 字段,按 OpenAI CC 协议)。
|
||||
//
|
||||
// usage 字段仅在客户端请求 stream_options.include_usage=true 时出现于上游响应中。
|
||||
// 网关会对上游强制打开 include_usage 以保证计费完整,并原样向下游透传 usage,
|
||||
// 让级联代理或下游计费系统也能拿到完整用量。
|
||||
func (s *OpenAIGatewayService) streamRawChatCompletions(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
originalModel string,
|
||||
billingModel string,
|
||||
upstreamModel string,
|
||||
reasoningEffort *string,
|
||||
serviceTier *string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if payload, ok := extractOpenAISSEDataLine(line); ok {
|
||||
trimmedPayload := strings.TrimSpace(payload)
|
||||
if trimmedPayload != "[DONE]" {
|
||||
usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload)
|
||||
if u := extractCCStreamUsage(payload); u != nil {
|
||||
usage = *u
|
||||
}
|
||||
if firstTokenMs == nil && !usageOnlyChunk {
|
||||
elapsed := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &elapsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
|
||||
zap.Error(werr),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
if line == "" {
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai chat_completions raw: stream read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ensureOpenAIChatStreamUsage 确保 raw Chat Completions 流式请求会让上游返回 usage。
|
||||
// usage 也会继续向下游透传,支持级联代理和下游计费系统。
|
||||
func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) {
|
||||
updated, err := sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func isOpenAIChatUsageOnlyStreamChunk(payload string) bool {
|
||||
if strings.TrimSpace(payload) == "" {
|
||||
return false
|
||||
}
|
||||
if !gjson.Get(payload, "usage").Exists() {
|
||||
return false
|
||||
}
|
||||
choices := gjson.Get(payload, "choices")
|
||||
return choices.Exists() && choices.IsArray() && len(choices.Array()) == 0
|
||||
}
|
||||
|
||||
// extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。
|
||||
// CC 协议中 usage 仅出现在末尾 chunk(且仅当 include_usage 生效时),
|
||||
// 但上游可能在多个 chunk 中重复——总是用最新值。
|
||||
func extractCCStreamUsage(payload string) *OpenAIUsage {
|
||||
usageResult := gjson.Get(payload, "usage")
|
||||
if !usageResult.Exists() || !usageResult.IsObject() {
|
||||
return nil
|
||||
}
|
||||
u := OpenAIUsage{
|
||||
InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()),
|
||||
OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()),
|
||||
}
|
||||
if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() {
|
||||
u.CacheReadInputTokens = int(cached.Int())
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
// bufferRawChatCompletions 透传上游 CC 非流式 JSON 响应。
|
||||
func (s *OpenAIGatewayService) bufferRawChatCompletions(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
originalModel string,
|
||||
billingModel string,
|
||||
upstreamModel string,
|
||||
reasoningEffort *string,
|
||||
serviceTier *string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
|
||||
}
|
||||
return nil, fmt.Errorf("read upstream body: %w", err)
|
||||
}
|
||||
|
||||
var ccResp apicompat.ChatCompletionsResponse
|
||||
var usage OpenAIUsage
|
||||
if err := json.Unmarshal(respBody, &ccResp); err == nil && ccResp.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: ccResp.Usage.PromptTokens,
|
||||
OutputTokens: ccResp.Usage.CompletionTokens,
|
||||
}
|
||||
if ccResp.Usage.PromptTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = ccResp.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||
c.Writer.Header().Set("Content-Type", ct)
|
||||
} else {
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildOpenAIChatCompletionsURL 拼接上游 Chat Completions 端点 URL。
|
||||
//
|
||||
// - base 已是 /chat/completions:原样返回
|
||||
// - base 以 /v1 结尾:追加 /chat/completions
|
||||
// - 其他情况:追加 /v1/chat/completions
|
||||
//
|
||||
// 与 buildOpenAIResponsesURL 是姐妹函数。
|
||||
func buildOpenAIChatCompletionsURL(base string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
|
||||
if strings.HasSuffix(normalized, "/chat/completions") {
|
||||
return normalized
|
||||
}
|
||||
if strings.HasSuffix(normalized, "/v1") {
|
||||
return normalized + "/chat/completions"
|
||||
}
|
||||
return normalized + "/v1/chat/completions"
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
base string
|
||||
want string
|
||||
}{
|
||||
// 已是 /chat/completions:原样返回
|
||||
{"already chat/completions", "https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions"},
|
||||
// 以 /v1 结尾:追加 /chat/completions
|
||||
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/chat/completions"},
|
||||
// 其他情况:追加 /v1/chat/completions
|
||||
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/chat/completions"},
|
||||
{"domain with trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/chat/completions"},
|
||||
// 第三方上游常见形式
|
||||
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
|
||||
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
|
||||
// 带空白字符
|
||||
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := buildOpenAIChatCompletionsURL(tt.base)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildOpenAIResponsesURL_ProbeURL 锁定 probe/测试端点使用的 URL 构建逻辑,
|
||||
// 确保 buildOpenAIResponsesURL 对标准 OpenAI base_url 格式均拼出 `/v1/responses`。
|
||||
func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
base string
|
||||
want string
|
||||
}{
|
||||
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/responses"},
|
||||
{"domain trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/responses"},
|
||||
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
|
||||
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
|
||||
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
|
||||
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := buildOpenAIResponsesURL(tt.base)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 9, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
require.Contains(t, rec.Body.String(), `"usage"`)
|
||||
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 17, result.Usage.InputTokens)
|
||||
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||
require.Equal(t, 6, result.Usage.CacheReadInputTokens)
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
cancel()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(``))
|
||||
}
|
||||
|
||||
func TestEnsureOpenAIChatStreamUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body, err := ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4"}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||||
|
||||
body, err = ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||||
}
|
||||
|
||||
func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader("toolong")),
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()}
|
||||
svc.cfg.Gateway.UpstreamResponseReadMaxBytes = 3
|
||||
|
||||
result, err := svc.bufferRawChatCompletions(c, resp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now())
|
||||
require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
}
|
||||
|
||||
func rawChatCompletionsTestConfig() *config.Config {
|
||||
return &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: false,
|
||||
AllowInsecureHTTP: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func rawChatCompletionsTestAccount() *Account {
|
||||
return &Account{
|
||||
ID: 101,
|
||||
Name: "raw-openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "http://upstream.example",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// openAIChatFailingWriter wraps a gin.ResponseWriter and starts returning
// write errors once failAfter successful writes have happened, simulating a
// downstream client that disconnects mid-stream.
type openAIChatFailingWriter struct {
	gin.ResponseWriter
	failAfter int // number of writes allowed before failures begin
	writes    int // count of successful writes performed so far
}
|
||||
|
||||
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -73,3 +96,278 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
||||
require.Empty(t, tier)
|
||||
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_UnknownModelDoesNotUseDefaultMappedModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt6","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_chat_unknown_model"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.4")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
|
||||
"",
|
||||
`data: {"type":"response.output_text.delta","delta":"ok"}`,
|
||||
"",
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 11, result.Usage.InputTokens)
|
||||
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||
require.Equal(t, 4, result.Usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
// TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns
// verifies the streaming path returns as soon as the terminal usage event
// arrives, even when the upstream never closes the connection (blocking
// reader). A hang is detected via the one-second select timeout below.
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	// failAfter=0 also simulates a disconnected client on the first write.
	c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
	body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	// Terminal event with usage; the blocking reader never signals EOF,
	// simulating an upstream that keeps the connection open afterwards.
	upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
	upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
	defer func() {
		require.NoError(t, upstreamStream.Close())
	}()
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
		Body:       upstreamStream,
	}}

	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	// Run the forward in a goroutine so a hang is caught by the timeout.
	type forwardResult struct {
		result *OpenAIForwardResult
		err    error
	}
	resultCh := make(chan forwardResult, 1)
	go func() {
		result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
		resultCh <- forwardResult{result: result, err: err}
	}()

	select {
	case got := <-resultCh:
		require.NoError(t, got.err)
		require.NotNil(t, got.result)
		// Usage taken from the terminal event, including cached-token detail.
		require.Equal(t, 17, got.result.Usage.InputTokens)
		require.Equal(t, 8, got.result.Usage.OutputTokens)
		require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
	case <-time.After(time.Second):
		require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
	}
}
|
||||
|
||||
// TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns
// verifies the buffered (stream:false) path returns as soon as the terminal
// usage event arrives, even when the upstream never closes the connection
// (blocking reader). A hang is detected via the one-second select timeout.
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	// stream:false exercises the buffered conversion path.
	body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	// Terminal event with usage; the blocking reader never signals EOF.
	upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
	upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
	defer func() {
		require.NoError(t, upstreamStream.Close())
	}()
	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
		Body:       upstreamStream,
	}}

	svc := &OpenAIGatewayService{httpUpstream: upstream}
	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	// Run the forward in a goroutine so a hang is caught by the timeout.
	type forwardResult struct {
		result *OpenAIForwardResult
		err    error
	}
	resultCh := make(chan forwardResult, 1)
	go func() {
		result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
		resultCh <- forwardResult{result: result, err: err}
	}()

	select {
	case got := <-resultCh:
		require.NoError(t, got.err)
		require.NotNil(t, got.result)
		require.Equal(t, 17, got.result.Usage.InputTokens)
		require.Equal(t, 8, got.result.Usage.OutputTokens)
		require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
		// The buffered response must be converted to chat-completions shape.
		require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`)
	case <-time.After(time.Second):
		require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
	}
}
|
||||
|
||||
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := "data: [DONE]\n\n"
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
require.Zero(t, result.Usage.InputTokens)
|
||||
require.Zero(t, result.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
cancel()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
@@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 6. Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
@@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var finalResponse *apicompat.ResponsesResponse
|
||||
var usage OpenAIUsage
|
||||
acc := apicompat.NewBufferedResponseAccumulator()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
payload := line[6:]
|
||||
|
||||
var event apicompat.ResponsesStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("openai messages buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate delta content for fallback when terminal output is empty.
|
||||
acc.ProcessEvent(&event)
|
||||
|
||||
// Terminal events carry the complete ResponsesResponse with output + usage.
|
||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil {
|
||||
finalResponse = event.Response
|
||||
if event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai messages buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if finalResponse == nil {
|
||||
@@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isOpenAICompatResponsesTerminalEvent reports whether the given Responses
// SSE event type (after whitespace trimming) marks the end of a stream.
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
	trimmed := strings.TrimSpace(eventType)
	return trimmed == "response.completed" ||
		trimmed == "response.done" ||
		trimmed == "response.incomplete" ||
		trimmed == "response.failed"
}
|
||||
|
||||
func isOpenAICompatDoneSentinelLine(line string) bool {
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||
}
|
||||
|
||||
// readOpenAICompatBufferedTerminal reads an upstream Responses SSE body until
// a terminal event (see isOpenAICompatResponsesTerminalEvent), the [DONE]
// sentinel, upstream close, a read error, or a data-interval timeout.
// It returns the terminal ResponsesResponse (nil when none arrived), the
// usage extracted from it, and the accumulator holding delta content as a
// fallback for callers when the terminal output is empty.
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
	resp *http.Response,
	logPrefix string,
	requestID string,
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
	acc := apicompat.NewBufferedResponseAccumulator()
	var usage OpenAIUsage
	if resp == nil || resp.Body == nil {
		return nil, usage, acc, errors.New("upstream response body is nil")
	}

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// Optional inactivity timeout between SSE lines (0 = disabled).
	streamInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	var timeoutCh <-chan time.Time
	var timeoutTimer *time.Timer
	// resetTimeout (re)arms the inactivity timer, draining a fired-but-unread
	// tick first so Reset does not race with a stale value in the channel.
	resetTimeout := func() {
		if streamInterval <= 0 {
			return
		}
		if timeoutTimer == nil {
			timeoutTimer = time.NewTimer(streamInterval)
			timeoutCh = timeoutTimer.C
			return
		}
		if !timeoutTimer.Stop() {
			select {
			case <-timeoutTimer.C:
			default:
			}
		}
		timeoutTimer.Reset(streamInterval)
	}
	// stopTimeout stops the timer and drains any pending tick.
	stopTimeout := func() {
		if timeoutTimer == nil {
			return
		}
		if !timeoutTimer.Stop() {
			select {
			case <-timeoutTimer.C:
			default:
			}
		}
	}
	resetTimeout()
	defer stopTimeout()

	// Reader goroutine: feeds scanned lines (or the final scan error) into
	// events. Closing done unblocks it on early return so it does not leak.
	type scanEvent struct {
		line string
		err  error
	}
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	go func() {
		defer close(events)
		for scanner.Scan() {
			select {
			case events <- scanEvent{line: scanner.Text()}:
			case <-done:
				return
			}
		}
		if err := scanner.Err(); err != nil {
			select {
			case events <- scanEvent{err: err}:
			case <-done:
			}
		}
	}()
	defer close(done)

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream closed without a terminal event.
				return nil, usage, acc, nil
			}
			resetTimeout()
			if ev.err != nil {
				// Context cancellation is expected on shutdown; only log real errors.
				if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
					logger.L().Warn(logPrefix+": read error",
						zap.Error(ev.err),
						zap.String("request_id", requestID),
					)
				}
				return nil, usage, acc, ev.err
			}

			if isOpenAICompatDoneSentinelLine(ev.line) {
				// [DONE] before a terminal event: stop reading, no final response.
				return nil, usage, acc, nil
			}
			payload, ok := extractOpenAISSEDataLine(ev.line)
			if !ok || payload == "" {
				continue
			}

			var event apicompat.ResponsesStreamEvent
			if err := json.Unmarshal([]byte(payload), &event); err != nil {
				logger.L().Warn(logPrefix+": failed to parse event",
					zap.Error(err),
					zap.String("request_id", requestID),
				)
				continue
			}

			// Accumulate delta content as a fallback for empty terminal output.
			acc.ProcessEvent(&event)

			// Terminal events carry the complete response with output + usage.
			if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
				if event.Response.Usage != nil {
					usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
				}
				return event.Response, usage, acc, nil
			}

		case <-timeoutCh:
			// No data within streamInterval: close the body to unblock the
			// reader goroutine and report the timeout.
			_ = resp.Body.Close()
			logger.L().Warn(logPrefix+": data interval timeout",
				zap.String("request_id", requestID),
				zap.Duration("interval", streamInterval),
			)
			return nil, usage, acc, fmt.Errorf("stream data interval timeout")
		}
	}
}
|
||||
|
||||
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
||||
// converts each to Anthropic SSE events, and writes them to the client.
|
||||
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
||||
@@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
clientDisconnected := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
// resultWithUsage builds the final result snapshot.
|
||||
resultWithUsage := func() *OpenAIForwardResult {
|
||||
return &OpenAIForwardResult{
|
||||
@@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
|
||||
// processDataLine handles a single "data: ..." SSE line from upstream.
|
||||
// Returns (clientDisconnected bool).
|
||||
processDataLine := func(payload string) bool {
|
||||
if firstChunk {
|
||||
firstChunk = false
|
||||
@@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil && event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
OutputTokens: event.Response.Usage.OutputTokens,
|
||||
}
|
||||
if event.Response.Usage.InputTokensDetails != nil {
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||
}
|
||||
|
||||
// Convert to Anthropic events
|
||||
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
|
||||
for _, evt := range events {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
logger.L().Info("openai messages stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return true
|
||||
if !clientDisconnected {
|
||||
for _, evt := range events {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(events) > 0 {
|
||||
if len(events) > 0 && !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
return isTerminalEvent
|
||||
}
|
||||
|
||||
// finalizeStream sends any remaining Anthropic events and returns the result.
|
||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Info("openai messages stream: client disconnected during final flush",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
@@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
)
|
||||
}
|
||||
}
|
||||
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
// ── Determine keepalive interval ──
|
||||
keepaliveInterval := time.Duration(0)
|
||||
@@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
|
||||
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
||||
if keepaliveInterval <= 0 {
|
||||
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
if isOpenAICompatDoneSentinelLine(line) {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
}
|
||||
handleScanErr(scanner.Err())
|
||||
return finalizeStream()
|
||||
if err := scanner.Err(); err != nil {
|
||||
handleScanErr(err)
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
return missingTerminalErr()
|
||||
}
|
||||
|
||||
// ── With keepalive: goroutine + channel + select ──
|
||||
@@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
@@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
@@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
for {
|
||||
@@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// Upstream closed
|
||||
return finalizeStream()
|
||||
return missingTerminalErr()
|
||||
}
|
||||
if ev.err != nil {
|
||||
handleScanErr(ev.err)
|
||||
return finalizeStream()
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
line := ev.line
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
if isOpenAICompatDoneSentinelLine(line) {
|
||||
return missingTerminalErr()
|
||||
}
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if processDataLine(line[6:]) {
|
||||
return resultWithUsage(), nil
|
||||
if processDataLine(payload) {
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
case <-keepaliveTicker.C:
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.L().Warn("openai messages stream: data interval timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("model", originalModel),
|
||||
zap.Duration("interval", streamInterval),
|
||||
)
|
||||
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
@@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return resultWithUsage(), nil
|
||||
clientDisconnected = true
|
||||
continue
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
@@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
|
||||
if usage == nil {
|
||||
return OpenAIUsage{}
|
||||
}
|
||||
result := OpenAIUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
}
|
||||
if usage.InputTokensDetails != nil {
|
||||
result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -186,6 +186,56 @@ func max(a, b int) int {
|
||||
return b
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_zero_usage",
|
||||
Usage: OpenAIUsage{},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}},
|
||||
User: &User{ID: 2000},
|
||||
Account: &Account{ID: 3000, Type: AccountTypeAPIKey},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID)
|
||||
require.Zero(t, usageRepo.lastLog.InputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.OutputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.CacheCreationTokens)
|
||||
require.Zero(t, usageRepo.lastLog.CacheReadTokens)
|
||||
require.Zero(t, usageRepo.lastLog.ImageOutputTokens)
|
||||
require.Zero(t, usageRepo.lastLog.ImageCount)
|
||||
require.Zero(t, usageRepo.lastLog.InputCost)
|
||||
require.Zero(t, usageRepo.lastLog.OutputCost)
|
||||
require.Zero(t, usageRepo.lastLog.TotalCost)
|
||||
require.Zero(t, usageRepo.lastLog.ActualCost)
|
||||
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Zero(t, billingRepo.lastCmd.BalanceCost)
|
||||
require.Zero(t, billingRepo.lastCmd.SubscriptionCost)
|
||||
require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost)
|
||||
require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost)
|
||||
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||
groupID := int64(11)
|
||||
groupRate := 1.4
|
||||
@@ -956,9 +1006,8 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingMode
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
// When channel did NOT map the model (ChannelMappedModel == OriginalModel),
|
||||
// billing should use result.BillingModel (the actual model used after group
|
||||
// DefaultMappedModel resolution), not the unmapped original model.
|
||||
// 渠道未发生模型映射时,应使用 result.BillingModel 中记录的实际上游计费模型,
|
||||
// 而不是未映射的原始请求模型。
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
|
||||
@@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
httpInvalidEncryptedContentRetryTried := false
|
||||
for {
|
||||
// Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
@@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
@@ -5041,13 +5041,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
|
||||
}
|
||||
|
||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
|
||||
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
|
||||
@@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
var usage OpenAIUsage
|
||||
imageCount := parsed.N
|
||||
var firstTokenMs *int
|
||||
if parsed.Stream {
|
||||
if parsed.Stream && isEventStreamResponse(resp.Header) {
|
||||
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
||||
usage := OpenAIUsage{}
|
||||
imageCount := 0
|
||||
var firstTokenMs *int
|
||||
var fallbackBody bytes.Buffer
|
||||
fallbackBytes := int64(0)
|
||||
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
seenSSEData := false
|
||||
fallbackTooLarge := false
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
@@ -824,11 +829,24 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" {
|
||||
dataBytes := []byte(data)
|
||||
mergeOpenAIUsage(&usage, dataBytes)
|
||||
if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount {
|
||||
imageCount = count
|
||||
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
|
||||
if data != "" && data != "[DONE]" {
|
||||
seenSSEData = true
|
||||
fallbackBody.Reset()
|
||||
fallbackBytes = 0
|
||||
dataBytes := []byte(data)
|
||||
mergeOpenAIUsage(&usage, dataBytes)
|
||||
if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
|
||||
imageCount = count
|
||||
}
|
||||
}
|
||||
} else if !seenSSEData && !fallbackTooLarge {
|
||||
fallbackBytes += int64(len(line))
|
||||
if fallbackBytes <= fallbackLimit {
|
||||
_, _ = fallbackBody.Write(line)
|
||||
} else {
|
||||
fallbackTooLarge = true
|
||||
fallbackBody.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
||||
return OpenAIUsage{}, 0, firstTokenMs, err
|
||||
}
|
||||
}
|
||||
if !seenSSEData && fallbackBody.Len() > 0 {
|
||||
body := bytes.TrimSpace(fallbackBody.Bytes())
|
||||
if len(body) > 0 {
|
||||
mergeOpenAIUsage(&usage, body)
|
||||
if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
|
||||
imageCount = count
|
||||
}
|
||||
}
|
||||
}
|
||||
return usage, imageCount, firstTokenMs, nil
|
||||
}
|
||||
|
||||
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
|
||||
if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 {
|
||||
return count
|
||||
}
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return 0
|
||||
}
|
||||
if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 {
|
||||
return count
|
||||
}
|
||||
if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 {
|
||||
return count
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String())
|
||||
if eventType == "" || !strings.HasSuffix(eventType, ".completed") {
|
||||
return 0
|
||||
}
|
||||
if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
||||
if dst == nil {
|
||||
return
|
||||
|
||||
@@ -446,6 +446,109 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req_img_stream_json"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 12, result.Usage.InputTokens)
|
||||
require.Equal(t, 21, result.Usage.OutputTokens)
|
||||
require.Equal(t, 9, result.Usage.ImageOutputTokens)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_stream_json_mislabeled"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 10, result.Usage.InputTokens)
|
||||
require.Equal(t, 18, result.Usage.OutputTokens)
|
||||
require.Equal(t, 8, result.Usage.ImageOutputTokens)
|
||||
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
|
||||
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
|
||||
|
||||
require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -2,44 +2,24 @@ package service
|
||||
|
||||
import "strings"
|
||||
|
||||
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
||||
// forwarding. Group-level default mapping only applies when the account itself
|
||||
// did not match any explicit model_mapping rule.
|
||||
// resolveOpenAIForwardModel 解析 OpenAI 兼容转发使用的模型。
|
||||
// defaultMappedModel 只服务于 /v1/messages 的 Claude 系列显式调度映射,
|
||||
// 不作为普通 OpenAI 请求的未知模型兜底。
|
||||
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
||||
if account == nil {
|
||||
if defaultMappedModel != "" {
|
||||
if defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
mappedModel, matched := account.ResolveMappedModel(requestedModel)
|
||||
if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) {
|
||||
if !matched && defaultMappedModel != "" && claudeMessagesDispatchFamily(requestedModel) != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
func isExplicitCodexModel(model string) bool {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
parts := strings.Split(model, "/")
|
||||
model = parts[len(parts)-1]
|
||||
}
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if getNormalizedCodexModel(model) != "" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(model, "-openai-compact") {
|
||||
base := strings.TrimSuffix(model, "-openai-compact")
|
||||
return getNormalizedCodexModel(base) != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveOpenAICompactForwardModel determines the compact-only upstream model
|
||||
// for /responses/compact requests. It never affects normal /responses traffic.
|
||||
// When no compact-specific mapping matches, the input model is returned as-is.
|
||||
|
||||
@@ -11,7 +11,7 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
expectedModel string
|
||||
}{
|
||||
{
|
||||
name: "falls back to group default when account has no mapping",
|
||||
name: "uses messages dispatch default for claude model",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
@@ -19,6 +19,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
name: "does not fall back to group default for invalid gpt model",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt6",
|
||||
defaultMappedModel: "gpt-5.4",
|
||||
expectedModel: "gpt6",
|
||||
},
|
||||
{
|
||||
name: "preserves explicit gpt-5.4 instead of group default",
|
||||
account: &Account{
|
||||
@@ -119,14 +128,14 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
|
||||
if withoutDefault != "gpt-5.4" {
|
||||
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4")
|
||||
withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "")
|
||||
if withoutDefault != "claude-opus-4-6" {
|
||||
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withoutDefault, "claude-opus-4-6")
|
||||
}
|
||||
|
||||
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
|
||||
withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")
|
||||
if withDefault != "gpt-5.4" {
|
||||
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4")
|
||||
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", withDefault, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,6 +214,10 @@ func TestNormalizeCodexModel(t *testing.T) {
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt-image-2": "gpt-image-2",
|
||||
"gpt-5.4-nano": "gpt-5.4-nano",
|
||||
"gpt-5.4-nano-high": "gpt-5.4-nano",
|
||||
"gpt6": "gpt6",
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
@@ -222,9 +235,21 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "oauth keeps codex normalization behavior",
|
||||
name: "oauth preserves unknown non codex model",
|
||||
account: &Account{Type: AccountTypeOAuth},
|
||||
model: "gemini-3-flash-preview",
|
||||
want: "gemini-3-flash-preview",
|
||||
},
|
||||
{
|
||||
name: "oauth preserves invalid gpt model",
|
||||
account: &Account{Type: AccountTypeOAuth},
|
||||
model: "gpt6",
|
||||
want: "gpt6",
|
||||
},
|
||||
{
|
||||
name: "oauth normalizes known codex alias",
|
||||
account: &Account{Type: AccountTypeOAuth},
|
||||
model: "gpt-5.4-high",
|
||||
want: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
|
||||
@@ -48,6 +48,49 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ResponsesUnknownModelDoesNotFallbackToGPT54(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
originalBody := []byte(`{"model":"gpt6","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(originalBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_unknown_model"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"type":"invalid_request_error","message":"model not found"}}`)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://chatgpt.com/backend-api/codex/responses", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "gpt6", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.NotEqual(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.True(t, rec.Code >= http.StatusBadRequest)
|
||||
}
|
||||
|
||||
type openAIPassthroughFailoverRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCalls []time.Time
|
||||
@@ -307,6 +350,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
|
||||
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
cancel()
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
@@ -405,6 +494,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
|
||||
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
cancel()
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n"))),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string {
|
||||
|
||||
// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
|
||||
type OpenAIWSIngressHooks struct {
|
||||
BeforeTurn func(turn int) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
// InitialRequestModel 是首帧渠道映射前的请求模型,只用于 usage metadata
|
||||
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
|
||||
InitialRequestModel string
|
||||
BeforeTurn func(turn int) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSLogValue(value string) string {
|
||||
@@ -1379,10 +1382,12 @@ func shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
if signals.HasFunctionCallOutputMissingCallID {
|
||||
return false
|
||||
}
|
||||
// If the client already sent tool-call context or item_reference anchors,
|
||||
// treat this as a full replay / self-contained continuation payload rather
|
||||
// than downgrading it into an inferred delta continuation.
|
||||
if signals.HasToolCallContext || signals.HasItemReferenceForAllCallIDs {
|
||||
// If the client already sent the actual tool-call context, treat this as
|
||||
// a full replay / self-contained continuation payload rather than
|
||||
// downgrading it into an inferred delta continuation. item_reference alone
|
||||
// is not enough on the store=false WS path: it still needs a valid prior
|
||||
// response anchor so upstream can resolve the referenced function_call.
|
||||
if signals.HasToolCallContext {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(expectedPreviousResponseID) != ""
|
||||
|
||||
@@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`))
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast","reasoning":{"effort":"HIGH"}}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -431,6 +431,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "priority", *result.ServiceTier)
|
||||
require.NotNil(t, result.ReasoningEffort)
|
||||
require.Equal(t, "high", *result.ReasoningEffort)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("未收到 passthrough turn 结果回调")
|
||||
}
|
||||
@@ -1488,7 +1490,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
|
||||
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenItemReferencesPresent(t *testing.T) {
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachWhenOnlyItemReferencesPresent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
@@ -1619,7 +1621,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
|
||||
|
||||
require.Equal(t, 1, captureDialer.DialCount())
|
||||
require.Len(t, captureConn.writes, 2)
|
||||
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 item_reference 锚点时不应自动补齐 previous_response_id")
|
||||
require.Equal(t, "resp_auto_prev_ref_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "仅有 item_reference 不足以自包含 function_call_output,应回填上一轮响应 ID")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
|
||||
|
||||
@@ -303,12 +303,12 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_item_reference_already_covers_all_call_ids",
|
||||
name: "infer_when_only_item_reference_covers_call_ids",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
|
||||
expectedPrevious: "resp_2",
|
||||
want: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "skip_when_function_call_output_missing_call_id",
|
||||
|
||||
@@ -124,6 +124,73 @@ func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []
|
||||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||||
}
|
||||
|
||||
type openAIWSPassthroughUsageMeta struct {
|
||||
serviceTier atomic.Pointer[string]
|
||||
reasoningEffort atomic.Pointer[string]
|
||||
|
||||
// 仅在 client->upstream filter goroutine 中读写;Load 侧通过上方原子指针同步。
|
||||
sessionRequestModel string
|
||||
}
|
||||
|
||||
func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta {
|
||||
meta := &openAIWSPassthroughUsageMeta{
|
||||
sessionRequestModel: strings.TrimSpace(initialRequestModel),
|
||||
}
|
||||
if meta.sessionRequestModel == "" {
|
||||
meta.sessionRequestModel = openAIWSPassthroughRequestModelForFrame(firstFrame)
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||||
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel))
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if model := openAIWSPassthroughRequestModelFromSessionFrame(payload); model != "" {
|
||||
m.sessionRequestModel = model
|
||||
}
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string {
|
||||
if m == nil {
|
||||
return openAIWSPassthroughRequestModelForFrame(payload)
|
||||
}
|
||||
if model := openAIWSPassthroughRequestModelForFrame(payload); model != "" {
|
||||
return model
|
||||
}
|
||||
return m.sessionRequestModel
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||||
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame))
|
||||
}
|
||||
|
||||
func openAIWSPassthroughRequestModelForFrame(payload []byte) string {
|
||||
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
|
||||
func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string {
|
||||
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
||||
}
|
||||
|
||||
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||
@@ -204,6 +271,11 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// silently passed through, defeating the policy on every frame after
|
||||
// the first.
|
||||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
||||
initialRequestModel := ""
|
||||
if hooks != nil {
|
||||
initialRequestModel = hooks.InitialRequestModel
|
||||
}
|
||||
usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage)
|
||||
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
||||
if policyErr != nil {
|
||||
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
||||
@@ -226,7 +298,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
}
|
||||
firstClientMessage = updatedFirst
|
||||
|
||||
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
|
||||
// 在 policy filter 之后再提取 service_tier / reasoning_effort 用于
|
||||
// usage 上报:filter
|
||||
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
||||
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
||||
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
||||
@@ -237,11 +310,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
||||
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||
// goroutine)之间同步当前 turn 的 service_tier。
|
||||
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
|
||||
// 可直接 Store/Load 而无需额外封装。
|
||||
var requestServiceTierPtr atomic.Pointer[string]
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
|
||||
// goroutine)之间同步当前 turn 的 usage metadata。
|
||||
usageMeta.initFromFirstFrame(firstClientMessage)
|
||||
|
||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||
if err != nil {
|
||||
@@ -327,6 +397,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||
capturedSessionModel = updated
|
||||
}
|
||||
usageMeta.updateSessionRequestModel(payload)
|
||||
requestModelForThisFrame := usageMeta.requestModelForFrame(payload)
|
||||
// Per-frame model first; if the client omits "model" on a
|
||||
// follow-up frame (legal in Realtime), fall back to the
|
||||
// session-level model captured from the first frame so the
|
||||
@@ -337,14 +409,14 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
model = capturedSessionModel
|
||||
}
|
||||
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
||||
// 多轮 passthrough billing:仅在成功(non-block / non-err)
|
||||
// 的 response.create 帧上更新 requestServiceTierPtr,使用
|
||||
// 多轮 passthrough usage:仅在成功(non-block / non-err)
|
||||
// 的 response.create 帧上更新 usageMeta,使用
|
||||
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
||||
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
||||
// - 非 response.create 帧(response.cancel /
|
||||
// conversation.item.create / session.update 等)不携带
|
||||
// per-response service_tier,不应覆盖前一轮值。
|
||||
// - blocked != nil:该帧不会发送上游,billing tier 应保持
|
||||
// per-response metadata,不应覆盖前一轮值。
|
||||
// - blocked != nil:该帧不会发送上游,usage metadata 应保持
|
||||
// 上一轮值。
|
||||
// - policyErr != nil:异常路径,保持上一轮值。
|
||||
// - 不带 service_tier 的 response.create 会让
|
||||
@@ -353,7 +425,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// service_tier 时按 default 处理,billing 应如实反映。
|
||||
if policyErr == nil && blocked == nil &&
|
||||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
|
||||
usageMeta.updateFromResponseCreate(out, requestModelForThisFrame)
|
||||
}
|
||||
return out, blocked, policyErr
|
||||
},
|
||||
@@ -397,7 +469,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
ServiceTier: usageMeta.serviceTier.Load(),
|
||||
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
@@ -445,7 +518,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
ServiceTier: usageMeta.serviceTier.Load(),
|
||||
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
|
||||
@@ -394,7 +394,8 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
|
||||
return nil
|
||||
}
|
||||
|
||||
rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
|
||||
sourceOrderID := o.ID
|
||||
rebateAmount, err := s.affiliateService.AccrueInviteRebateForOrder(txCtx, o.UserID, o.Amount, &sourceOrderID)
|
||||
if err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
|
||||
@@ -82,10 +82,11 @@ const backendModeDBTimeout = 5 * time.Second
|
||||
|
||||
// cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL)
|
||||
type cachedGatewayForwardingSettings struct {
|
||||
fingerprintUnification bool
|
||||
metadataPassthrough bool
|
||||
cchSigning bool
|
||||
expiresAt int64 // unix nano
|
||||
fingerprintUnification bool
|
||||
metadataPassthrough bool
|
||||
cchSigning bool
|
||||
anthropicCacheTTL1hInjection bool
|
||||
expiresAt int64 // unix nano
|
||||
}
|
||||
|
||||
var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings
|
||||
@@ -1245,6 +1246,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
|
||||
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
|
||||
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
|
||||
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
|
||||
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
||||
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
||||
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
||||
@@ -1305,10 +1307,11 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
||||
})
|
||||
gatewayForwardingSF.Forget("gateway_forwarding")
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
||||
fingerprintUnification: settings.EnableFingerprintUnification,
|
||||
metadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
cchSigning: settings.EnableCCHSigning,
|
||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||
fingerprintUnification: settings.EnableFingerprintUnification,
|
||||
metadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
cchSigning: settings.EnableCCHSigning,
|
||||
anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||
})
|
||||
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
||||
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
||||
@@ -1415,22 +1418,30 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetGatewayForwardingSettings returns cached gateway forwarding settings.
|
||||
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
|
||||
// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
|
||||
func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
|
||||
type gatewayForwardingSettingsResult struct {
|
||||
fp, mp, cch, cacheTTL1h bool
|
||||
}
|
||||
|
||||
func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult {
|
||||
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning
|
||||
return gatewayForwardingSettingsResult{
|
||||
fp: cached.fingerprintUnification,
|
||||
mp: cached.metadataPassthrough,
|
||||
cch: cached.cchSigning,
|
||||
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
|
||||
}
|
||||
}
|
||||
}
|
||||
type gwfResult struct {
|
||||
fp, mp, cch bool
|
||||
}
|
||||
val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) {
|
||||
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil
|
||||
return gatewayForwardingSettingsResult{
|
||||
fp: cached.fingerprintUnification,
|
||||
mp: cached.metadataPassthrough,
|
||||
cch: cached.cchSigning,
|
||||
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout)
|
||||
@@ -1439,16 +1450,18 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
|
||||
SettingKeyEnableFingerprintUnification,
|
||||
SettingKeyEnableMetadataPassthrough,
|
||||
SettingKeyEnableCCHSigning,
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("failed to get gateway forwarding settings", "error", err)
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
||||
fingerprintUnification: true,
|
||||
metadataPassthrough: false,
|
||||
cchSigning: false,
|
||||
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
|
||||
fingerprintUnification: true,
|
||||
metadataPassthrough: false,
|
||||
cchSigning: false,
|
||||
anthropicCacheTTL1hInjection: false,
|
||||
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
|
||||
})
|
||||
return gwfResult{true, false, false}, nil
|
||||
return gatewayForwardingSettingsResult{fp: true}, nil
|
||||
}
|
||||
fp := true
|
||||
if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" {
|
||||
@@ -1456,18 +1469,33 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
|
||||
}
|
||||
mp := values[SettingKeyEnableMetadataPassthrough] == "true"
|
||||
cch := values[SettingKeyEnableCCHSigning] == "true"
|
||||
cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
||||
fingerprintUnification: fp,
|
||||
metadataPassthrough: mp,
|
||||
cchSigning: cch,
|
||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||
fingerprintUnification: fp,
|
||||
metadataPassthrough: mp,
|
||||
cchSigning: cch,
|
||||
anthropicCacheTTL1hInjection: cacheTTL1h,
|
||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||
})
|
||||
return gwfResult{fp, mp, cch}, nil
|
||||
return gatewayForwardingSettingsResult{fp: fp, mp: mp, cch: cch, cacheTTL1h: cacheTTL1h}, nil
|
||||
})
|
||||
if r, ok := val.(gwfResult); ok {
|
||||
return r.fp, r.mp, r.cch
|
||||
if r, ok := val.(gatewayForwardingSettingsResult); ok {
|
||||
return r
|
||||
}
|
||||
return true, false, false // fail-open defaults
|
||||
return gatewayForwardingSettingsResult{fp: true}
|
||||
}
|
||||
|
||||
// GetGatewayForwardingSettings returns cached gateway forwarding settings.
|
||||
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
|
||||
// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
|
||||
func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
|
||||
result := s.getGatewayForwardingSettingsCached(ctx)
|
||||
return result.fp, result.mp, result.cch
|
||||
}
|
||||
|
||||
// IsAnthropicCacheTTL1hInjectionEnabled 检查是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl。
|
||||
func (s *SettingService) IsAnthropicCacheTTL1hInjectionEnabled(ctx context.Context) bool {
|
||||
return s.getGatewayForwardingSettingsCached(ctx).cacheTTL1h
|
||||
}
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
@@ -1880,12 +1908,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyMaxClaudeCodeVersion: "",
|
||||
|
||||
// 分组隔离(默认不允许未分组 Key 调度)
|
||||
SettingKeyAllowUngroupedKeyScheduling: "false",
|
||||
SettingPaymentVisibleMethodAlipaySource: "",
|
||||
SettingPaymentVisibleMethodWxpaySource: "",
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
||||
SettingPaymentVisibleMethodWxpayEnabled: "false",
|
||||
openAIAdvancedSchedulerSettingKey: "false",
|
||||
SettingKeyAllowUngroupedKeyScheduling: "false",
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
|
||||
SettingPaymentVisibleMethodAlipaySource: "",
|
||||
SettingPaymentVisibleMethodWxpaySource: "",
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
||||
SettingPaymentVisibleMethodWxpayEnabled: "false",
|
||||
openAIAdvancedSchedulerSettingKey: "false",
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, defaults)
|
||||
@@ -2228,6 +2257,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
}
|
||||
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
|
||||
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
|
||||
result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
|
||||
|
||||
// Web search emulation: quick enabled check from the JSON config
|
||||
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||
|
||||
@@ -149,9 +149,10 @@ type SystemSettings struct {
|
||||
BackendModeEnabled bool
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
||||
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
||||
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
||||
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
||||
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
||||
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
||||
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||
|
||||
// Web Search Emulation
|
||||
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
||||
|
||||
85
backend/migrations/134_affiliate_ledger_audit_snapshots.sql
Normal file
85
backend/migrations/134_affiliate_ledger_audit_snapshots.sql
Normal file
@@ -0,0 +1,85 @@
|
||||
-- 邀请返利流水补充订单关联和转余额快照。
|
||||
-- 这些字段只用于审计展示;历史旧流水无法可靠反推的字段保持 NULL,避免把当前状态误展示为历史状态。
|
||||
|
||||
ALTER TABLE user_affiliate_ledger
|
||||
ADD COLUMN IF NOT EXISTS source_order_id BIGINT NULL REFERENCES payment_orders(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE user_affiliate_ledger
|
||||
ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8) NULL;
|
||||
|
||||
ALTER TABLE user_affiliate_ledger
|
||||
ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8) NULL;
|
||||
|
||||
ALTER TABLE user_affiliate_ledger
|
||||
ADD COLUMN IF NOT EXISTS aff_frozen_quota_after DECIMAL(20,8) NULL;
|
||||
|
||||
ALTER TABLE user_affiliate_ledger
|
||||
ADD COLUMN IF NOT EXISTS aff_history_quota_after DECIMAL(20,8) NULL;
|
||||
|
||||
COMMENT ON COLUMN user_affiliate_ledger.source_order_id IS '产生该返利流水的充值订单;转余额或无法可靠回填的历史数据为 NULL';
|
||||
COMMENT ON COLUMN user_affiliate_ledger.balance_after IS '邀请返利转余额后的用户余额快照;无法取得时为 NULL';
|
||||
COMMENT ON COLUMN user_affiliate_ledger.aff_quota_after IS '邀请返利转余额后的可用返利额度快照;无法取得时为 NULL';
|
||||
COMMENT ON COLUMN user_affiliate_ledger.aff_frozen_quota_after IS '邀请返利转余额后的冻结返利额度快照;无法取得时为 NULL';
|
||||
COMMENT ON COLUMN user_affiliate_ledger.aff_history_quota_after IS '邀请返利转余额后的历史返利总额快照;无法取得时为 NULL';
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_source_order_id
|
||||
ON user_affiliate_ledger(source_order_id)
|
||||
WHERE source_order_id IS NOT NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_rebate_lookup
|
||||
ON user_affiliate_ledger(action, source_order_id, user_id, source_user_id, created_at)
|
||||
WHERE action = 'accrue';
|
||||
|
||||
-- 尽力回填 PR #2169 合并后、该迁移前已经产生的返利流水。
|
||||
-- 只有在同一订单只能匹配到一条返利流水时才回填,避免把多笔同额流水错误绑定到订单。
|
||||
WITH rebate_audits AS (
|
||||
SELECT po.id AS order_id,
|
||||
po.user_id AS invitee_user_id,
|
||||
invitee_aff.inviter_id,
|
||||
rebate_detail.rebate_amount,
|
||||
pal.created_at AS audit_created_at
|
||||
FROM payment_audit_logs pal
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT substring(
|
||||
pal.detail
|
||||
FROM '"rebateAmount"[[:space:]]*:[[:space:]]*(-?[0-9]+(\.[0-9]+)?)'
|
||||
)::numeric AS rebate_amount
|
||||
) rebate_detail
|
||||
JOIN payment_orders po ON po.id::text = pal.order_id
|
||||
JOIN user_affiliates invitee_aff ON invitee_aff.user_id = po.user_id
|
||||
WHERE pal.action = 'AFFILIATE_REBATE_APPLIED'
|
||||
AND rebate_detail.rebate_amount IS NOT NULL
|
||||
),
|
||||
ranked_matches AS (
|
||||
SELECT ual.id AS ledger_id,
|
||||
ra.order_id,
|
||||
COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count,
|
||||
COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY ual.id
|
||||
ORDER BY ABS(EXTRACT(EPOCH FROM (ual.created_at - ra.audit_created_at))), ra.order_id
|
||||
) AS ledger_rank
|
||||
FROM rebate_audits ra
|
||||
JOIN user_affiliate_ledger ual
|
||||
ON ual.action = 'accrue'
|
||||
AND ual.source_order_id IS NULL
|
||||
AND ual.user_id = ra.inviter_id
|
||||
AND ual.source_user_id = ra.invitee_user_id
|
||||
AND ABS(ual.amount - ra.rebate_amount) < 0.00000001
|
||||
AND ual.created_at BETWEEN ra.audit_created_at - INTERVAL '10 minutes'
|
||||
AND ra.audit_created_at + INTERVAL '10 minutes'
|
||||
)
|
||||
UPDATE user_affiliate_ledger ual
|
||||
SET source_order_id = ranked_matches.order_id,
|
||||
updated_at = NOW()
|
||||
FROM ranked_matches
|
||||
WHERE ual.id = ranked_matches.ledger_id
|
||||
AND ranked_matches.order_match_count = 1
|
||||
AND ranked_matches.ledger_match_count = 1
|
||||
AND ranked_matches.ledger_rank = 1
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM user_affiliate_ledger existing
|
||||
WHERE existing.source_order_id = ranked_matches.order_id
|
||||
AND existing.action = 'accrue'
|
||||
);
|
||||
@@ -127,3 +127,18 @@ func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) {
|
||||
require.Contains(t, sql, "oidc_connect_enabled")
|
||||
require.Contains(t, sql, "'false'")
|
||||
}
|
||||
|
||||
func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T) {
|
||||
content, err := FS.ReadFile("134_affiliate_ledger_audit_snapshots.sql")
|
||||
require.NoError(t, err)
|
||||
|
||||
sql := string(content)
|
||||
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS source_order_id BIGINT")
|
||||
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8)")
|
||||
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8)")
|
||||
require.Contains(t, sql, "substring(")
|
||||
require.Contains(t, sql, `"rebateAmount"`)
|
||||
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count")
|
||||
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count")
|
||||
require.NotContains(t, sql, "detail::jsonb")
|
||||
}
|
||||
|
||||
@@ -23,6 +23,72 @@ export interface ListAffiliateUsersParams {
|
||||
search?: string
|
||||
}
|
||||
|
||||
export interface ListAffiliateRecordsParams {
|
||||
page?: number
|
||||
page_size?: number
|
||||
search?: string
|
||||
start_at?: string
|
||||
end_at?: string
|
||||
sort_by?: string
|
||||
sort_order?: 'asc' | 'desc'
|
||||
timezone?: string
|
||||
}
|
||||
|
||||
export interface AffiliateInviteRecord {
|
||||
inviter_id: number
|
||||
inviter_email: string
|
||||
inviter_username: string
|
||||
invitee_id: number
|
||||
invitee_email: string
|
||||
invitee_username: string
|
||||
aff_code: string
|
||||
total_rebate: number
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface AffiliateRebateRecord {
|
||||
order_id: number
|
||||
out_trade_no: string
|
||||
inviter_id: number
|
||||
inviter_email: string
|
||||
inviter_username: string
|
||||
invitee_id: number
|
||||
invitee_email: string
|
||||
invitee_username: string
|
||||
order_amount: number
|
||||
pay_amount: number
|
||||
rebate_amount: number
|
||||
payment_type: string
|
||||
order_status: string
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface AffiliateTransferRecord {
|
||||
ledger_id: number
|
||||
user_id: number
|
||||
user_email: string
|
||||
username: string
|
||||
amount: number
|
||||
balance_after?: number | null
|
||||
available_quota_after?: number | null
|
||||
frozen_quota_after?: number | null
|
||||
history_quota_after?: number | null
|
||||
snapshot_available: boolean
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface AffiliateUserOverview {
|
||||
user_id: number
|
||||
email: string
|
||||
username: string
|
||||
aff_code: string
|
||||
rebate_rate_percent: number
|
||||
invited_count: number
|
||||
rebated_invitee_count: number
|
||||
available_quota: number
|
||||
history_quota: number
|
||||
}
|
||||
|
||||
export interface UpdateAffiliateUserRequest {
|
||||
aff_code?: string
|
||||
aff_rebate_rate_percent?: number | null
|
||||
@@ -97,12 +163,68 @@ export async function batchSetRate(
|
||||
return data
|
||||
}
|
||||
|
||||
function recordParams(params: ListAffiliateRecordsParams = {}) {
|
||||
return {
|
||||
page: params.page ?? 1,
|
||||
page_size: params.page_size ?? 20,
|
||||
search: params.search ?? '',
|
||||
start_at: params.start_at || undefined,
|
||||
end_at: params.end_at || undefined,
|
||||
sort_by: params.sort_by || undefined,
|
||||
sort_order: params.sort_order || undefined,
|
||||
timezone: params.timezone || undefined,
|
||||
}
|
||||
}
|
||||
|
||||
export async function listInviteRecords(
|
||||
params: ListAffiliateRecordsParams = {},
|
||||
): Promise<PaginatedResponse<AffiliateInviteRecord>> {
|
||||
const { data } = await apiClient.get<PaginatedResponse<AffiliateInviteRecord>>(
|
||||
'/admin/affiliates/invites',
|
||||
{ params: recordParams(params) },
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function listRebateRecords(
|
||||
params: ListAffiliateRecordsParams = {},
|
||||
): Promise<PaginatedResponse<AffiliateRebateRecord>> {
|
||||
const { data } = await apiClient.get<PaginatedResponse<AffiliateRebateRecord>>(
|
||||
'/admin/affiliates/rebates',
|
||||
{ params: recordParams(params) },
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function listTransferRecords(
|
||||
params: ListAffiliateRecordsParams = {},
|
||||
): Promise<PaginatedResponse<AffiliateTransferRecord>> {
|
||||
const { data } = await apiClient.get<PaginatedResponse<AffiliateTransferRecord>>(
|
||||
'/admin/affiliates/transfers',
|
||||
{ params: recordParams(params) },
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function getUserOverview(
|
||||
userId: number,
|
||||
): Promise<AffiliateUserOverview> {
|
||||
const { data } = await apiClient.get<AffiliateUserOverview>(
|
||||
`/admin/affiliates/users/${userId}/overview`,
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export const affiliatesAPI = {
|
||||
listUsers,
|
||||
lookupUsers,
|
||||
updateUserSettings,
|
||||
clearUserSettings,
|
||||
batchSetRate,
|
||||
listInviteRecords,
|
||||
listRebateRecords,
|
||||
listTransferRecords,
|
||||
getUserOverview,
|
||||
}
|
||||
|
||||
export default affiliatesAPI
|
||||
|
||||
@@ -439,6 +439,7 @@ export interface SystemSettings {
|
||||
enable_fingerprint_unification: boolean;
|
||||
enable_metadata_passthrough: boolean;
|
||||
enable_cch_signing: boolean;
|
||||
enable_anthropic_cache_ttl_1h_injection: boolean;
|
||||
web_search_emulation_enabled?: boolean;
|
||||
|
||||
// Payment configuration
|
||||
@@ -609,6 +610,7 @@ export interface UpdateSettingsRequest {
|
||||
enable_fingerprint_unification?: boolean;
|
||||
enable_metadata_passthrough?: boolean;
|
||||
enable_cch_signing?: boolean;
|
||||
enable_anthropic_cache_ttl_1h_injection?: boolean;
|
||||
// Payment configuration
|
||||
payment_enabled?: boolean;
|
||||
payment_min_amount?: number;
|
||||
|
||||
@@ -249,7 +249,7 @@ export interface BalanceHistoryResponse extends PaginatedResponse<BalanceHistory
|
||||
* @param id - User ID
|
||||
* @param page - Page number
|
||||
* @param pageSize - Items per page
|
||||
* @param type - Optional type filter (balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
* @param type - Optional type filter (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
|
||||
* @returns Paginated balance history with total_recharged
|
||||
*/
|
||||
export async function getUserBalanceHistory(
|
||||
|
||||
@@ -779,6 +779,110 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI Compact mode -->
|
||||
<div v-if="allOpenAIPassthroughCapable" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div class="flex-1 pr-4">
|
||||
<label
|
||||
id="bulk-edit-openai-compact-mode-label"
|
||||
class="input-label mb-0"
|
||||
for="bulk-edit-openai-compact-mode-enabled"
|
||||
>
|
||||
{{ t('admin.accounts.openai.compactMode') }}
|
||||
</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.openai.compactModeDesc') }}
|
||||
</p>
|
||||
</div>
|
||||
<input
|
||||
v-model="enableOpenAICompactMode"
|
||||
id="bulk-edit-openai-compact-mode-enabled"
|
||||
type="checkbox"
|
||||
aria-controls="bulk-edit-openai-compact-mode"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
id="bulk-edit-openai-compact-mode"
|
||||
:class="!enableOpenAICompactMode && 'pointer-events-none opacity-50'"
|
||||
>
|
||||
<Select
|
||||
v-model="openAICompactMode"
|
||||
data-testid="bulk-edit-openai-compact-mode-select"
|
||||
:options="openAICompactModeOptions"
|
||||
aria-labelledby="bulk-edit-openai-compact-mode-label"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI Compact model mapping -->
|
||||
<div v-if="allOpenAIPassthroughCapable" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div class="flex-1 pr-4">
|
||||
<label
|
||||
id="bulk-edit-openai-compact-model-mapping-label"
|
||||
class="input-label mb-0"
|
||||
for="bulk-edit-openai-compact-model-mapping-enabled"
|
||||
>
|
||||
{{ t('admin.accounts.openai.compactModelMapping') }}
|
||||
</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.openai.compactModelMappingDesc') }}
|
||||
</p>
|
||||
</div>
|
||||
<input
|
||||
v-model="enableOpenAICompactModelMapping"
|
||||
id="bulk-edit-openai-compact-model-mapping-enabled"
|
||||
type="checkbox"
|
||||
aria-controls="bulk-edit-openai-compact-model-mapping"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
id="bulk-edit-openai-compact-model-mapping"
|
||||
:class="!enableOpenAICompactModelMapping && 'pointer-events-none opacity-50'"
|
||||
>
|
||||
<div v-if="openAICompactModelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in openAICompactModelMappings"
|
||||
:key="index"
|
||||
class="flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
v-model="mapping.from"
|
||||
type="text"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.fromModel')"
|
||||
data-testid="bulk-edit-openai-compact-model-mapping-input"
|
||||
/>
|
||||
<span class="text-gray-400">→</span>
|
||||
<input
|
||||
v-model="mapping.to"
|
||||
type="text"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.toModel')"
|
||||
data-testid="bulk-edit-openai-compact-model-mapping-input"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
|
||||
@click="removeOpenAICompactModelMapping(index)"
|
||||
>
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
|
||||
data-testid="bulk-edit-openai-compact-model-mapping-add"
|
||||
@click="addOpenAICompactModelMapping"
|
||||
>
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
|
||||
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
@@ -989,7 +1093,7 @@ import { ref, watch, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Proxy as ProxyConfig, AdminGroup, AccountPlatform, AccountType } from '@/types'
|
||||
import type { Proxy as ProxyConfig, AdminGroup, AccountPlatform, AccountType, OpenAICompactMode } from '@/types'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
@@ -1115,6 +1219,8 @@ const enableOpenAIPassthrough = ref(false)
|
||||
const enableOpenAIWSMode = ref(false)
|
||||
const enableOpenAIAPIKeyWSMode = ref(false)
|
||||
const enableCodexCLIOnly = ref(false)
|
||||
const enableOpenAICompactMode = ref(false)
|
||||
const enableOpenAICompactModelMapping = ref(false)
|
||||
const enableRpmLimit = ref(false)
|
||||
|
||||
// State - field values
|
||||
@@ -1140,6 +1246,8 @@ const openaiPassthroughEnabled = ref(false)
|
||||
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||
const codexCLIOnlyEnabled = ref(false)
|
||||
const openAICompactMode = ref<OpenAICompactMode>('auto')
|
||||
const openAICompactModelMappings = ref<ModelMapping[]>([])
|
||||
const rpmLimitEnabled = ref(false)
|
||||
const bulkBaseRpm = ref<number | null>(null)
|
||||
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
|
||||
@@ -1178,6 +1286,11 @@ const openAIWSModeOptions = computed(() => [
|
||||
{ value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
|
||||
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
|
||||
])
|
||||
const openAICompactModeOptions = computed(() => [
|
||||
{ value: 'auto', label: t('admin.accounts.openai.compactModeAuto') },
|
||||
{ value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
|
||||
{ value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
|
||||
])
|
||||
const openAIWSModeConcurrencyHintKey = computed(() =>
|
||||
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||
)
|
||||
@@ -1194,6 +1307,14 @@ const removeModelMapping = (index: number) => {
|
||||
modelMappings.value.splice(index, 1)
|
||||
}
|
||||
|
||||
const addOpenAICompactModelMapping = () => {
|
||||
openAICompactModelMappings.value.push({ from: '', to: '' })
|
||||
}
|
||||
|
||||
const removeOpenAICompactModelMapping = (index: number) => {
|
||||
openAICompactModelMappings.value.splice(index, 1)
|
||||
}
|
||||
|
||||
const addPresetMapping = (from: string, to: string) => {
|
||||
const exists = modelMappings.value.some((m) => m.from === from)
|
||||
if (exists) {
|
||||
@@ -1262,6 +1383,10 @@ const buildModelMappingObject = (): Record<string, string> | null => {
|
||||
)
|
||||
}
|
||||
|
||||
const buildOpenAICompactModelMapping = (): Record<string, string> | null => {
|
||||
return buildModelMappingPayload('mapping', [], openAICompactModelMappings.value)
|
||||
}
|
||||
|
||||
const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
const updates: Record<string, unknown> = {}
|
||||
const credentials: Record<string, unknown> = {}
|
||||
@@ -1350,10 +1475,6 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
credentialsChanged = true
|
||||
}
|
||||
|
||||
if (credentialsChanged) {
|
||||
updates.credentials = credentials
|
||||
}
|
||||
|
||||
if (enableOpenAIWSMode.value) {
|
||||
const extra = ensureExtra()
|
||||
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||
@@ -1375,6 +1496,16 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
extra.codex_cli_only = codexCLIOnlyEnabled.value
|
||||
}
|
||||
|
||||
if (enableOpenAICompactMode.value) {
|
||||
const extra = ensureExtra()
|
||||
extra.openai_compact_mode = openAICompactMode.value
|
||||
}
|
||||
|
||||
if (enableOpenAICompactModelMapping.value) {
|
||||
credentials.compact_model_mapping = buildOpenAICompactModelMapping() ?? {}
|
||||
credentialsChanged = true
|
||||
}
|
||||
|
||||
// RPM limit settings (写入 extra 字段)
|
||||
if (enableRpmLimit.value) {
|
||||
const extra = ensureExtra()
|
||||
@@ -1402,6 +1533,10 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
umqExtra.user_msg_queue_enabled = false // 清理旧字段(JSONB merge)
|
||||
}
|
||||
|
||||
if (credentialsChanged) {
|
||||
updates.credentials = credentials
|
||||
}
|
||||
|
||||
return Object.keys(updates).length > 0 ? updates : null
|
||||
}
|
||||
|
||||
@@ -1467,6 +1602,8 @@ const handleSubmit = async () => {
|
||||
enableOpenAIWSMode.value ||
|
||||
enableOpenAIAPIKeyWSMode.value ||
|
||||
enableCodexCLIOnly.value ||
|
||||
enableOpenAICompactMode.value ||
|
||||
enableOpenAICompactModelMapping.value ||
|
||||
enableRpmLimit.value ||
|
||||
userMsgQueueMode.value !== null
|
||||
|
||||
@@ -1567,6 +1704,8 @@ watch(
|
||||
enableOpenAIWSMode.value = false
|
||||
enableOpenAIAPIKeyWSMode.value = false
|
||||
enableCodexCLIOnly.value = false
|
||||
enableOpenAICompactMode.value = false
|
||||
enableOpenAICompactModelMapping.value = false
|
||||
enableRpmLimit.value = false
|
||||
|
||||
// Reset all values
|
||||
@@ -1588,6 +1727,8 @@ watch(
|
||||
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||
codexCLIOnlyEnabled.value = false
|
||||
openAICompactMode.value = 'auto'
|
||||
openAICompactModelMappings.value = []
|
||||
rpmLimitEnabled.value = false
|
||||
bulkBaseRpm.value = null
|
||||
bulkRpmStrategy.value = 'tiered'
|
||||
|
||||
@@ -217,6 +217,44 @@ describe('BulkEditAccountModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('筛选 OpenAI 账号批量编辑应提交 Compact 模式和专属模型映射', async () => {
|
||||
const wrapper = mountModal({
|
||||
accountIds: [],
|
||||
selectedPlatforms: [],
|
||||
selectedTypes: [],
|
||||
target: {
|
||||
mode: 'filtered',
|
||||
filters: { platform: 'openai' },
|
||||
previewCount: 12,
|
||||
selectedPlatforms: ['openai'],
|
||||
selectedTypes: ['oauth', 'apikey']
|
||||
}
|
||||
})
|
||||
|
||||
await wrapper.get('#bulk-edit-openai-compact-mode-enabled').setValue(true)
|
||||
await wrapper.get('[data-testid="bulk-edit-openai-compact-mode-select"]').setValue('force_on')
|
||||
await wrapper.get('#bulk-edit-openai-compact-model-mapping-enabled').setValue(true)
|
||||
await wrapper.get('[data-testid="bulk-edit-openai-compact-model-mapping-add"]').trigger('click')
|
||||
const inputs = wrapper.findAll('[data-testid="bulk-edit-openai-compact-model-mapping-input"]')
|
||||
await inputs[0].setValue('gpt-5.4')
|
||||
await inputs[1].setValue('gpt-5.4-openai-compact')
|
||||
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
|
||||
await flushPromises()
|
||||
|
||||
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
|
||||
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith({
|
||||
filters: { platform: 'openai' },
|
||||
extra: {
|
||||
openai_compact_mode: 'force_on'
|
||||
},
|
||||
credentials: {
|
||||
compact_model_mapping: {
|
||||
'gpt-5.4': 'gpt-5.4-openai-compact'
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('OpenAI 账号批量编辑可关闭自动透传', async () => {
|
||||
const wrapper = mountModal({
|
||||
selectedPlatforms: ['openai'],
|
||||
|
||||
@@ -196,6 +196,7 @@ const totalPages = computed(() => Math.ceil(total.value / pageSize) || 1)
|
||||
const typeOptions = computed(() => [
|
||||
{ value: '', label: t('admin.users.allTypes') },
|
||||
{ value: 'balance', label: t('admin.users.typeBalance') },
|
||||
{ value: 'affiliate_balance', label: t('admin.users.typeAffiliateBalance') },
|
||||
{ value: 'admin_balance', label: t('admin.users.typeAdminBalance') },
|
||||
{ value: 'concurrency', label: t('admin.users.typeConcurrency') },
|
||||
{ value: 'admin_concurrency', label: t('admin.users.typeAdminConcurrency') },
|
||||
@@ -235,7 +236,7 @@ const loadHistory = async (page: number) => {
|
||||
const isAdminType = (type: string) => type === 'admin_balance' || type === 'admin_concurrency'
|
||||
|
||||
// Helper: check if balance type (includes admin_balance)
|
||||
const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance'
|
||||
const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' || type === 'affiliate_balance'
|
||||
|
||||
// Helper: check if subscription type
|
||||
const isSubscriptionType = (type: string) => type === 'subscription'
|
||||
@@ -291,6 +292,8 @@ const getItemTitle = (item: BalanceHistoryItem) => {
|
||||
switch (item.type) {
|
||||
case 'balance':
|
||||
return t('redeem.balanceAddedRedeem')
|
||||
case 'affiliate_balance':
|
||||
return t('redeem.balanceAddedAffiliate')
|
||||
case 'admin_balance':
|
||||
return item.value >= 0 ? t('redeem.balanceAddedAdmin') : t('redeem.balanceDeductedAdmin')
|
||||
case 'concurrency':
|
||||
|
||||
@@ -123,6 +123,7 @@ import { useI18n } from 'vue-i18n'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import Select from './Select.vue'
|
||||
import { getConfiguredTablePageSizeOptions, normalizeTablePageSize } from '@/utils/tablePreferences'
|
||||
import { setPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
@@ -224,6 +225,7 @@ const goToPage = (newPage: number) => {
|
||||
const handlePageSizeChange = (value: string | number | boolean | null) => {
|
||||
if (value === null || typeof value === 'boolean') return
|
||||
const newPageSize = normalizeTablePageSize(typeof value === 'string' ? parseInt(value, 10) : value)
|
||||
setPersistedPageSize(newPageSize)
|
||||
emit('update:pageSize', newPageSize)
|
||||
}
|
||||
|
||||
|
||||
@@ -721,6 +721,19 @@ const adminNavItems = computed((): NavItem[] => {
|
||||
{ path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon },
|
||||
{ path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true },
|
||||
{
|
||||
path: '/admin/affiliates',
|
||||
label: t('nav.affiliateManagement'),
|
||||
icon: UsersIcon,
|
||||
hideInSimpleMode: true,
|
||||
expandOnly: true,
|
||||
featureFlag: flagAffiliate,
|
||||
children: [
|
||||
{ path: '/admin/affiliates/invites', label: t('nav.affiliateInviteRecords'), icon: UsersIcon },
|
||||
{ path: '/admin/affiliates/rebates', label: t('nav.affiliateRebateRecords'), icon: OrderIcon },
|
||||
{ path: '/admin/affiliates/transfers', label: t('nav.affiliateTransferRecords'), icon: CreditCardIcon },
|
||||
],
|
||||
},
|
||||
{
|
||||
path: '/admin/orders',
|
||||
label: t('nav.orderManagement'),
|
||||
|
||||
@@ -1,9 +1,29 @@
|
||||
import { getConfiguredTableDefaultPageSize, normalizeTablePageSize } from '@/utils/tablePreferences'
|
||||
|
||||
/**
|
||||
* 读取当前系统配置的表格默认每页条数。
|
||||
* 不再使用本地持久化缓存,所有页面统一以通用表格设置为准。
|
||||
*/
|
||||
const STORAGE_KEY = 'table-page-size'
|
||||
|
||||
export function getPersistedPageSize(fallback = getConfiguredTableDefaultPageSize()): number {
|
||||
if (typeof window !== 'undefined') {
|
||||
try {
|
||||
const stored = window.localStorage.getItem(STORAGE_KEY)
|
||||
if (stored !== null) {
|
||||
const parsed = Number(stored)
|
||||
if (Number.isFinite(parsed)) {
|
||||
return normalizeTablePageSize(parsed)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to read persisted page size:', error)
|
||||
}
|
||||
}
|
||||
return normalizeTablePageSize(getConfiguredTableDefaultPageSize() || fallback)
|
||||
}
|
||||
|
||||
export function setPersistedPageSize(size: number): void {
|
||||
if (typeof window === 'undefined') return
|
||||
try {
|
||||
window.localStorage.setItem(STORAGE_KEY, String(size))
|
||||
} catch (error) {
|
||||
console.warn('Failed to persist page size:', error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { ref, reactive, onUnmounted, toRaw } from 'vue'
|
||||
import { useDebounceFn } from '@vueuse/core'
|
||||
import type { BasePaginationResponse, FetchOptions } from '@/types'
|
||||
import { getPersistedPageSize } from './usePersistedPageSize'
|
||||
import { getPersistedPageSize, setPersistedPageSize } from './usePersistedPageSize'
|
||||
|
||||
interface PaginationState {
|
||||
page: number
|
||||
@@ -88,6 +88,7 @@ export function useTableLoader<T, P extends Record<string, any>>(options: TableL
|
||||
const handlePageSizeChange = (size: number) => {
|
||||
pagination.page_size = size
|
||||
pagination.page = 1
|
||||
setPersistedPageSize(size)
|
||||
load()
|
||||
}
|
||||
|
||||
|
||||
@@ -347,6 +347,10 @@ export default {
|
||||
usage: 'Usage',
|
||||
redeem: 'Redeem',
|
||||
affiliate: 'Affiliate Rebates',
|
||||
affiliateManagement: 'Affiliate Rebates',
|
||||
affiliateInviteRecords: 'Invite Records',
|
||||
affiliateRebateRecords: 'Rebate Records',
|
||||
affiliateTransferRecords: 'Transfer Records',
|
||||
profile: 'Profile',
|
||||
users: 'Users',
|
||||
groups: 'Groups',
|
||||
@@ -1046,6 +1050,7 @@ export default {
|
||||
recentActivity: 'Recent Activity',
|
||||
historyWillAppear: 'Your redemption history will appear here',
|
||||
balanceAddedRedeem: 'Balance Added (Redeem)',
|
||||
balanceAddedAffiliate: 'Balance Added (Affiliate Transfer)',
|
||||
balanceAddedAdmin: 'Balance Added (Admin)',
|
||||
balanceDeductedAdmin: 'Balance Deducted (Admin)',
|
||||
concurrencyAddedRedeem: 'Concurrency Added (Redeem)',
|
||||
@@ -1635,6 +1640,49 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
affiliates: {
|
||||
invitesDescription: 'View site-wide inviter and invitee relationships',
|
||||
rebatesDescription: 'View recharge orders that generated affiliate rebates',
|
||||
transfersDescription: 'View affiliate quota transfers into account balance',
|
||||
errors: {
|
||||
loadFailed: 'Failed to load affiliate records'
|
||||
},
|
||||
records: {
|
||||
search: 'Search',
|
||||
searchPlaceholder: 'Email, username, user ID, or order number',
|
||||
startAt: 'Start date',
|
||||
endAt: 'End date',
|
||||
inviter: 'Inviter',
|
||||
invitee: 'Invitee',
|
||||
user: 'User',
|
||||
affCode: 'Invite Code',
|
||||
order: 'Order',
|
||||
totalRebate: 'Total Rebate',
|
||||
orderAmount: 'Top-up Amount',
|
||||
payAmount: 'Paid Amount',
|
||||
rebateAmount: 'Rebate Amount',
|
||||
paymentType: 'Payment Method',
|
||||
orderStatus: 'Order Status',
|
||||
transferAmount: 'Transfer Amount',
|
||||
balanceAfter: 'Balance After',
|
||||
availableQuotaAfter: 'Available After',
|
||||
frozenQuotaAfter: 'Frozen After',
|
||||
historyQuotaAfter: 'Historical Rebate After',
|
||||
invitedAt: 'Invited At',
|
||||
rebatedAt: 'Rebated At',
|
||||
transferredAt: 'Transferred At'
|
||||
},
|
||||
overview: {
|
||||
title: 'Affiliate User Overview',
|
||||
affCode: 'Invite Code',
|
||||
rebateRate: 'Rebate Rate',
|
||||
invitedCount: 'Invited Users',
|
||||
rebatedInviteeCount: 'Rebated Invitees',
|
||||
availableQuota: 'Available Quota',
|
||||
historyQuota: 'Historical Rebate'
|
||||
}
|
||||
},
|
||||
|
||||
// Users
|
||||
users: {
|
||||
title: 'User Management',
|
||||
@@ -1787,6 +1835,7 @@ export default {
|
||||
noBalanceHistory: 'No records found for this user',
|
||||
allTypes: 'All Types',
|
||||
typeBalance: 'Balance (Redeem)',
|
||||
typeAffiliateBalance: 'Balance (Affiliate Transfer)',
|
||||
typeAdminBalance: 'Balance (Admin)',
|
||||
typeConcurrency: 'Concurrency (Redeem)',
|
||||
typeAdminConcurrency: 'Concurrency (Admin)',
|
||||
@@ -5019,6 +5068,8 @@ export default {
|
||||
metadataPassthroughHint: 'Pass through client\'s original metadata.user_id without rewriting. May improve upstream cache hit rates.',
|
||||
cchSigning: 'CCH Signing',
|
||||
cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.',
|
||||
anthropicCacheTTL1hInjection: 'Anthropic Cache TTL Injection',
|
||||
anthropicCacheTTL1hInjectionHint: 'When enabled, existing ephemeral cache_control blocks in Anthropic OAuth/Setup Token request bodies are forced to 1h; response usage is billed back as 5m by default, with account-level TTL billing override taking priority.',
|
||||
},
|
||||
webSearchEmulation: {
|
||||
title: 'Web Search Emulation',
|
||||
|
||||
@@ -347,6 +347,10 @@ export default {
|
||||
usage: '使用记录',
|
||||
redeem: '兑换',
|
||||
affiliate: '邀请返利',
|
||||
affiliateManagement: '邀请返利',
|
||||
affiliateInviteRecords: '邀请记录',
|
||||
affiliateRebateRecords: '返利记录',
|
||||
affiliateTransferRecords: '提取记录',
|
||||
profile: '个人资料',
|
||||
users: '用户管理',
|
||||
groups: '分组管理',
|
||||
@@ -1050,6 +1054,7 @@ export default {
|
||||
recentActivity: '最近活动',
|
||||
historyWillAppear: '您的兑换历史将显示在这里',
|
||||
balanceAddedRedeem: '余额充值(兑换)',
|
||||
balanceAddedAffiliate: '余额充值(返利转入)',
|
||||
balanceAddedAdmin: '余额充值(管理员)',
|
||||
balanceDeductedAdmin: '余额扣除(管理员)',
|
||||
concurrencyAddedRedeem: '并发增加(兑换)',
|
||||
@@ -1656,6 +1661,49 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
affiliates: {
|
||||
invitesDescription: '查看全站邀请关系和被邀请用户累计返利',
|
||||
rebatesDescription: '查看每一笔产生返利的充值订单',
|
||||
transfersDescription: '查看返利额度转入账户余额的提取流水',
|
||||
errors: {
|
||||
loadFailed: '加载邀请返利记录失败'
|
||||
},
|
||||
records: {
|
||||
search: '搜索',
|
||||
searchPlaceholder: '邮箱、用户名、用户 ID、订单号',
|
||||
startAt: '开始日期',
|
||||
endAt: '结束日期',
|
||||
inviter: '邀请人',
|
||||
invitee: '被邀请人',
|
||||
user: '用户',
|
||||
affCode: '邀请码',
|
||||
order: '订单',
|
||||
totalRebate: '累计返利',
|
||||
orderAmount: '充值金额',
|
||||
payAmount: '支付金额',
|
||||
rebateAmount: '返利金额',
|
||||
paymentType: '支付方式',
|
||||
orderStatus: '订单状态',
|
||||
transferAmount: '提取金额',
|
||||
balanceAfter: '提取后余额',
|
||||
availableQuotaAfter: '提取后可提',
|
||||
frozenQuotaAfter: '提取后冻结',
|
||||
historyQuotaAfter: '提取后历史返利',
|
||||
invitedAt: '邀请时间',
|
||||
rebatedAt: '返利时间',
|
||||
transferredAt: '提取时间'
|
||||
},
|
||||
overview: {
|
||||
title: '用户返利概览',
|
||||
affCode: '邀请码',
|
||||
rebateRate: '返利比例',
|
||||
invitedCount: '邀请人数',
|
||||
rebatedInviteeCount: '已产生返利人数',
|
||||
availableQuota: '可提余额',
|
||||
historyQuota: '历史返利'
|
||||
}
|
||||
},
|
||||
|
||||
// Users Management
|
||||
users: {
|
||||
title: '用户管理',
|
||||
@@ -1844,6 +1892,7 @@ export default {
|
||||
noBalanceHistory: '暂无变动记录',
|
||||
allTypes: '全部类型',
|
||||
typeBalance: '余额(兑换码)',
|
||||
typeAffiliateBalance: '余额(返利转入)',
|
||||
typeAdminBalance: '余额(管理员调整)',
|
||||
typeConcurrency: '并发(兑换码)',
|
||||
typeAdminConcurrency: '并发(管理员调整)',
|
||||
@@ -5178,6 +5227,8 @@ export default {
|
||||
metadataPassthroughHint: '透传客户端原始 metadata.user_id,不进行重写。可能提高上游缓存命中率。',
|
||||
cchSigning: 'CCH 签名',
|
||||
cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。',
|
||||
anthropicCacheTTL1hInjection: 'Anthropic 缓存 TTL 注入',
|
||||
anthropicCacheTTL1hInjectionHint: '开启后,对 Anthropic OAuth/Setup Token 请求体中已有的 ephemeral 缓存块强制写入 1h;响应 usage 默认按 5m 回写计费,账号级 TTL 计费设置优先。',
|
||||
},
|
||||
webSearchEmulation: {
|
||||
title: 'Web Search 模拟',
|
||||
|
||||
@@ -517,6 +517,46 @@ const routes: RouteRecordRaw[] = [
|
||||
descriptionKey: 'admin.usage.description'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/affiliates',
|
||||
redirect: '/admin/affiliates/invites'
|
||||
},
|
||||
{
|
||||
path: '/admin/affiliates/invites',
|
||||
name: 'AdminAffiliateInvites',
|
||||
component: () => import('@/views/admin/affiliates/AdminAffiliateInvitesView.vue'),
|
||||
meta: {
|
||||
requiresAuth: true,
|
||||
requiresAdmin: true,
|
||||
title: 'Affiliate Invite Records',
|
||||
titleKey: 'nav.affiliateInviteRecords',
|
||||
descriptionKey: 'admin.affiliates.invitesDescription'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/affiliates/rebates',
|
||||
name: 'AdminAffiliateRebates',
|
||||
component: () => import('@/views/admin/affiliates/AdminAffiliateRebatesView.vue'),
|
||||
meta: {
|
||||
requiresAuth: true,
|
||||
requiresAdmin: true,
|
||||
title: 'Affiliate Rebate Records',
|
||||
titleKey: 'nav.affiliateRebateRecords',
|
||||
descriptionKey: 'admin.affiliates.rebatesDescription'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/affiliates/transfers',
|
||||
name: 'AdminAffiliateTransfers',
|
||||
component: () => import('@/views/admin/affiliates/AdminAffiliateTransfersView.vue'),
|
||||
meta: {
|
||||
requiresAuth: true,
|
||||
requiresAdmin: true,
|
||||
title: 'Affiliate Transfer Records',
|
||||
titleKey: 'nav.affiliateTransferRecords',
|
||||
descriptionKey: 'admin.affiliates.transfersDescription'
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
// ==================== Payment Admin Routes ====================
|
||||
|
||||
@@ -3057,6 +3057,31 @@
|
||||
</div>
|
||||
<Toggle v-model="form.enable_cch_signing" />
|
||||
</div>
|
||||
|
||||
<!-- Anthropic Cache TTL 1h Injection -->
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label
|
||||
class="text-sm font-medium text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
{{
|
||||
t(
|
||||
"admin.settings.gatewayForwarding.anthropicCacheTTL1hInjection",
|
||||
)
|
||||
}}
|
||||
</label>
|
||||
<p class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t(
|
||||
"admin.settings.gatewayForwarding.anthropicCacheTTL1hInjectionHint",
|
||||
)
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
<Toggle
|
||||
v-model="form.enable_anthropic_cache_ttl_1h_injection"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Web Search Emulation -->
|
||||
@@ -5810,6 +5835,7 @@ const form = reactive<SettingsForm>({
|
||||
enable_fingerprint_unification: true,
|
||||
enable_metadata_passthrough: false,
|
||||
enable_cch_signing: false,
|
||||
enable_anthropic_cache_ttl_1h_injection: false,
|
||||
// Balance & quota notification
|
||||
balance_low_notify_enabled: false,
|
||||
balance_low_notify_threshold: 0,
|
||||
@@ -6718,6 +6744,8 @@ async function saveSettings() {
|
||||
enable_fingerprint_unification: form.enable_fingerprint_unification,
|
||||
enable_metadata_passthrough: form.enable_metadata_passthrough,
|
||||
enable_cch_signing: form.enable_cch_signing,
|
||||
enable_anthropic_cache_ttl_1h_injection:
|
||||
form.enable_anthropic_cache_ttl_1h_injection,
|
||||
// Payment configuration
|
||||
payment_enabled: form.payment_enabled,
|
||||
payment_min_amount: Number(form.payment_min_amount) || 0,
|
||||
|
||||
@@ -362,6 +362,7 @@ const baseSettingsResponse = {
|
||||
enable_fingerprint_unification: true,
|
||||
enable_metadata_passthrough: false,
|
||||
enable_cch_signing: false,
|
||||
enable_anthropic_cache_ttl_1h_injection: false,
|
||||
payment_enabled: true,
|
||||
payment_min_amount: 1,
|
||||
payment_max_amount: 10000,
|
||||
@@ -567,6 +568,26 @@ describe("admin SettingsView payment visible method controls", () => {
|
||||
expect(payload).not.toHaveProperty("payment_visible_method_wxpay_enabled");
|
||||
});
|
||||
|
||||
it("submits Anthropic cache TTL injection gateway setting", async () => {
|
||||
getSettings.mockResolvedValueOnce({
|
||||
...baseSettingsResponse,
|
||||
enable_anthropic_cache_ttl_1h_injection: true,
|
||||
});
|
||||
|
||||
const wrapper = mountView();
|
||||
|
||||
await flushPromises();
|
||||
await wrapper.find("form").trigger("submit.prevent");
|
||||
await flushPromises();
|
||||
|
||||
expect(updateSettings).toHaveBeenCalledTimes(1);
|
||||
expect(updateSettings).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
enable_anthropic_cache_ttl_1h_injection: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("updates provider enablement immediately and reloads providers", async () => {
|
||||
const provider = {
|
||||
id: 7,
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
<template>
|
||||
<AdminAffiliateRecordsTable type="invites" />
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
|
||||
</script>
|
||||
@@ -0,0 +1,7 @@
|
||||
<template>
|
||||
<AdminAffiliateRecordsTable type="rebates" />
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
|
||||
</script>
|
||||
@@ -0,0 +1,407 @@
|
||||
<template>
|
||||
<AppLayout>
|
||||
<TablePageLayout>
|
||||
<template #filters>
|
||||
<div class="flex flex-wrap items-center gap-3">
|
||||
<div class="relative w-full md:w-80">
|
||||
<Icon name="search" size="md" class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-400" />
|
||||
<input v-model="filters.search" type="text" class="input pl-10" :placeholder="t('admin.affiliates.records.searchPlaceholder')" @input="debounceLoad" />
|
||||
</div>
|
||||
<input v-model="filters.start_at" type="date" class="input w-full sm:w-44" :title="t('admin.affiliates.records.startAt')" @change="reloadFromFirstPage" />
|
||||
<input v-model="filters.end_at" type="date" class="input w-full sm:w-44" :title="t('admin.affiliates.records.endAt')" @change="reloadFromFirstPage" />
|
||||
<button class="btn btn-secondary px-2 md:px-3" :disabled="loading" :title="t('common.refresh')" @click="loadRecords">
|
||||
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<template #table>
|
||||
<DataTable
|
||||
:columns="columns"
|
||||
:data="records"
|
||||
:loading="loading"
|
||||
:server-side-sort="true"
|
||||
default-sort-key="created_at"
|
||||
default-sort-order="desc"
|
||||
:sort-storage-key="sortStorageKey"
|
||||
@sort="handleSort"
|
||||
>
|
||||
<template #cell-inviter="{ row }">
|
||||
<UserCell
|
||||
:id="row.inviter_id"
|
||||
:email="row.inviter_email"
|
||||
:username="row.inviter_username"
|
||||
:clickable="props.type !== 'transfers'"
|
||||
@open="openUserOverview"
|
||||
/>
|
||||
</template>
|
||||
<template #cell-invitee="{ row }">
|
||||
<UserCell
|
||||
:id="row.invitee_id"
|
||||
:email="row.invitee_email"
|
||||
:username="row.invitee_username"
|
||||
:clickable="props.type !== 'transfers'"
|
||||
@open="openUserOverview"
|
||||
/>
|
||||
</template>
|
||||
<template #cell-user="{ row }">
|
||||
<UserCell
|
||||
:id="row.user_id"
|
||||
:email="row.user_email"
|
||||
:username="row.username"
|
||||
:clickable="true"
|
||||
@open="openUserOverview"
|
||||
/>
|
||||
</template>
|
||||
<template #cell-aff_code="{ row }">
|
||||
<span class="font-mono text-sm text-gray-700 dark:text-gray-300">{{ row.aff_code || '-' }}</span>
|
||||
</template>
|
||||
<template #cell-order="{ row }">
|
||||
<div class="space-y-0.5">
|
||||
<div class="font-mono text-sm text-gray-900 dark:text-white">#{{ row.order_id }}</div>
|
||||
<div class="max-w-56 truncate text-sm text-gray-500 dark:text-dark-400">{{ row.out_trade_no }}</div>
|
||||
</div>
|
||||
</template>
|
||||
<template #cell-payment_type="{ row }">
|
||||
{{ t('payment.methods.' + row.payment_type, row.payment_type || '-') }}
|
||||
</template>
|
||||
<template #cell-order_status="{ row }">
|
||||
<OrderStatusBadge :status="row.order_status" />
|
||||
</template>
|
||||
<template #cell-total_rebate="{ row }">
|
||||
<AmountText :value="row.total_rebate" />
|
||||
</template>
|
||||
<template #cell-order_amount="{ row }">
|
||||
<AmountText :value="row.order_amount" />
|
||||
</template>
|
||||
<template #cell-pay_amount="{ row }">
|
||||
<span class="text-sm text-gray-900 dark:text-white">¥{{ formatAmount(row.pay_amount) }}</span>
|
||||
</template>
|
||||
<template #cell-rebate_amount="{ row }">
|
||||
<AmountText :value="row.rebate_amount" strong />
|
||||
</template>
|
||||
<template #cell-amount="{ row }">
|
||||
<AmountText :value="row.amount" strong />
|
||||
</template>
|
||||
<template #cell-balance_after="{ row }">
|
||||
<NullableAmountText :value="row.balance_after" />
|
||||
</template>
|
||||
<template #cell-available_quota_after="{ row }">
|
||||
<NullableAmountText :value="row.available_quota_after" />
|
||||
</template>
|
||||
<template #cell-frozen_quota_after="{ row }">
|
||||
<NullableAmountText :value="row.frozen_quota_after" />
|
||||
</template>
|
||||
<template #cell-history_quota_after="{ row }">
|
||||
<NullableAmountText :value="row.history_quota_after" />
|
||||
</template>
|
||||
<template #cell-created_at="{ row }">
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ formatDateTime(row.created_at) }}</span>
|
||||
</template>
|
||||
</DataTable>
|
||||
</template>
|
||||
|
||||
<template #pagination>
|
||||
<Pagination
|
||||
v-if="pagination.total > 0"
|
||||
:page="pagination.page"
|
||||
:total="pagination.total"
|
||||
:page-size="pagination.page_size"
|
||||
@update:page="handlePageChange"
|
||||
@update:pageSize="handlePageSizeChange"
|
||||
/>
|
||||
</template>
|
||||
</TablePageLayout>
|
||||
|
||||
<BaseDialog
|
||||
:show="overviewDialog"
|
||||
:title="t('admin.affiliates.overview.title')"
|
||||
width="normal"
|
||||
@close="overviewDialog = false"
|
||||
>
|
||||
<div v-if="overviewLoading" class="flex justify-center py-8">
|
||||
<div class="h-6 w-6 animate-spin rounded-full border-2 border-primary-500 border-t-transparent"></div>
|
||||
</div>
|
||||
<div v-else-if="selectedOverview" class="space-y-4">
|
||||
<div class="rounded-lg border border-gray-100 bg-gray-50 p-4 dark:border-dark-700 dark:bg-dark-800">
|
||||
<div class="font-mono text-sm text-gray-900 dark:text-white">#{{ selectedOverview.user_id }}</div>
|
||||
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">{{ selectedOverview.email || '-' }}</div>
|
||||
<div class="mt-0.5 text-sm text-gray-500 dark:text-dark-400">{{ selectedOverview.username || '-' }}</div>
|
||||
</div>
|
||||
<div class="grid gap-3 sm:grid-cols-2">
|
||||
<OverviewStat :label="t('admin.affiliates.overview.affCode')" :value="selectedOverview.aff_code || '-'" mono />
|
||||
<OverviewStat :label="t('admin.affiliates.overview.rebateRate')" :value="formatPercent(selectedOverview.rebate_rate_percent)" />
|
||||
<OverviewStat :label="t('admin.affiliates.overview.invitedCount')" :value="String(selectedOverview.invited_count)" />
|
||||
<OverviewStat :label="t('admin.affiliates.overview.rebatedInviteeCount')" :value="String(selectedOverview.rebated_invitee_count)" />
|
||||
<OverviewStat :label="t('admin.affiliates.overview.availableQuota')" :value="'$' + formatAmount(selectedOverview.available_quota)" />
|
||||
<OverviewStat :label="t('admin.affiliates.overview.historyQuota')" :value="'$' + formatAmount(selectedOverview.history_quota)" />
|
||||
</div>
|
||||
</div>
|
||||
</BaseDialog>
|
||||
</AppLayout>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, defineComponent, h, onMounted, reactive, ref, type PropType } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
||||
import DataTable from '@/components/common/DataTable.vue'
|
||||
import Pagination from '@/components/common/Pagination.vue'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import OrderStatusBadge from '@/components/payment/OrderStatusBadge.vue'
|
||||
import type { Column } from '@/components/common/types'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { affiliatesAPI, type AffiliateInviteRecord, type AffiliateRebateRecord, type AffiliateTransferRecord, type AffiliateUserOverview, type ListAffiliateRecordsParams } from '@/api/admin/affiliates'
|
||||
import type { PaginatedResponse } from '@/types'
|
||||
import { extractI18nErrorMessage } from '@/utils/apiError'
|
||||
import { formatDateTime as formatDisplayDateTime } from '@/utils/format'
|
||||
|
||||
// The three record flavors this shared table can display; selected via the
// `type` prop and used to pick the API endpoint and column set.
type RecordType = 'invites' | 'rebates' | 'transfers'
type AffiliateRecord = AffiliateInviteRecord | AffiliateRebateRecord | AffiliateTransferRecord

// `type` is fixed per page (each wrapper page renders this component once).
const props = defineProps<{
  type: RecordType
}>()

const { t } = useI18n()
const appStore = useAppStore()

// Table state: current page of records plus a loading flag for spinners.
const loading = ref(false)
const records = ref<AffiliateRecord[]>([])
// Server-side filters; `search` is debounced, the date bounds reload eagerly.
const filters = reactive({ search: '', start_at: '', end_at: '' })
const pagination = reactive({ page: 1, page_size: 20, total: 0 })

// User-overview dialog state (opened by clicking a user cell).
const overviewDialog = ref(false)
const overviewLoading = ref(false)
const selectedOverview = ref<AffiliateUserOverview | null>(null)

// Pending debounce handle for the search input.
let debounceTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
const columns = computed<Column[]>(() => {
|
||||
if (props.type === 'invites') {
|
||||
return [
|
||||
{ key: 'inviter', label: t('admin.affiliates.records.inviter'), sortable: true },
|
||||
{ key: 'invitee', label: t('admin.affiliates.records.invitee'), sortable: true },
|
||||
{ key: 'aff_code', label: t('admin.affiliates.records.affCode'), sortable: true },
|
||||
{ key: 'total_rebate', label: t('admin.affiliates.records.totalRebate'), sortable: true },
|
||||
{ key: 'created_at', label: t('admin.affiliates.records.invitedAt'), sortable: true },
|
||||
]
|
||||
}
|
||||
if (props.type === 'rebates') {
|
||||
return [
|
||||
{ key: 'order', label: t('admin.affiliates.records.order'), sortable: true },
|
||||
{ key: 'inviter', label: t('admin.affiliates.records.inviter'), sortable: true },
|
||||
{ key: 'invitee', label: t('admin.affiliates.records.invitee'), sortable: true },
|
||||
{ key: 'order_amount', label: t('admin.affiliates.records.orderAmount'), sortable: true },
|
||||
{ key: 'pay_amount', label: t('admin.affiliates.records.payAmount'), sortable: true },
|
||||
{ key: 'rebate_amount', label: t('admin.affiliates.records.rebateAmount') },
|
||||
{ key: 'payment_type', label: t('admin.affiliates.records.paymentType'), sortable: true },
|
||||
{ key: 'order_status', label: t('admin.affiliates.records.orderStatus'), sortable: true },
|
||||
{ key: 'created_at', label: t('admin.affiliates.records.rebatedAt'), sortable: true },
|
||||
]
|
||||
}
|
||||
return [
|
||||
{ key: 'user', label: t('admin.affiliates.records.user'), sortable: true },
|
||||
{ key: 'amount', label: t('admin.affiliates.records.transferAmount'), sortable: true },
|
||||
{ key: 'balance_after', label: t('admin.affiliates.records.balanceAfter'), sortable: true },
|
||||
{ key: 'available_quota_after', label: t('admin.affiliates.records.availableQuotaAfter'), sortable: true },
|
||||
{ key: 'frozen_quota_after', label: t('admin.affiliates.records.frozenQuotaAfter'), sortable: true },
|
||||
{ key: 'history_quota_after', label: t('admin.affiliates.records.historyQuotaAfter'), sortable: true },
|
||||
{ key: 'created_at', label: t('admin.affiliates.records.transferredAt'), sortable: true },
|
||||
]
|
||||
})
|
||||
|
||||
const sortStorageKey = computed(() => `admin-affiliate-${props.type}-table-sort`)
|
||||
|
||||
function loadInitialSortState(): { sort_by: string; sort_order: 'asc' | 'desc' } {
|
||||
const fallback = { sort_by: 'created_at', sort_order: 'desc' as 'asc' | 'desc' }
|
||||
try {
|
||||
const raw = localStorage.getItem(sortStorageKey.value)
|
||||
if (!raw) return fallback
|
||||
const parsed = JSON.parse(raw) as { key?: string; order?: string }
|
||||
const key = typeof parsed.key === 'string' ? parsed.key : ''
|
||||
if (!columns.value.some((column) => column.key === key && column.sortable)) return fallback
|
||||
return {
|
||||
sort_by: key,
|
||||
sort_order: parsed.order === 'asc' ? 'asc' : 'desc',
|
||||
}
|
||||
} catch {
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
const sortState = reactive(loadInitialSortState())
|
||||
|
||||
function userTimezone(): string {
|
||||
try {
|
||||
return Intl.DateTimeFormat().resolvedOptions().timeZone
|
||||
} catch {
|
||||
return 'UTC'
|
||||
}
|
||||
}
|
||||
|
||||
function buildParams(): ListAffiliateRecordsParams {
|
||||
return {
|
||||
page: pagination.page,
|
||||
page_size: pagination.page_size,
|
||||
search: filters.search.trim() || undefined,
|
||||
start_at: filters.start_at || undefined,
|
||||
end_at: filters.end_at || undefined,
|
||||
sort_by: sortState.sort_by,
|
||||
sort_order: sortState.sort_order,
|
||||
timezone: userTimezone(),
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchRecords(params: ListAffiliateRecordsParams): Promise<PaginatedResponse<AffiliateRecord>> {
|
||||
if (props.type === 'invites') {
|
||||
return affiliatesAPI.listInviteRecords(params)
|
||||
}
|
||||
if (props.type === 'rebates') {
|
||||
return affiliatesAPI.listRebateRecords(params)
|
||||
}
|
||||
return affiliatesAPI.listTransferRecords(params)
|
||||
}
|
||||
|
||||
async function loadRecords() {
|
||||
loading.value = true
|
||||
try {
|
||||
const res = await fetchRecords(buildParams())
|
||||
records.value = res.items || []
|
||||
pagination.total = res.total || 0
|
||||
} catch (error) {
|
||||
appStore.showError(extractI18nErrorMessage(error, t, 'admin.affiliates.errors', t('common.error')))
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function debounceLoad() {
|
||||
if (debounceTimer) clearTimeout(debounceTimer)
|
||||
debounceTimer = setTimeout(() => reloadFromFirstPage(), 300)
|
||||
}
|
||||
|
||||
function reloadFromFirstPage() {
|
||||
pagination.page = 1
|
||||
void loadRecords()
|
||||
}
|
||||
|
||||
function handlePageChange(page: number) {
|
||||
pagination.page = page
|
||||
void loadRecords()
|
||||
}
|
||||
|
||||
function handlePageSizeChange(size: number) {
|
||||
pagination.page_size = size
|
||||
pagination.page = 1
|
||||
void loadRecords()
|
||||
}
|
||||
|
||||
function handleSort(key: string, order: 'asc' | 'desc') {
|
||||
sortState.sort_by = key
|
||||
sortState.sort_order = order
|
||||
pagination.page = 1
|
||||
void loadRecords()
|
||||
}
|
||||
|
||||
function formatAmount(value: number | null | undefined): string {
|
||||
return Number(value || 0).toFixed(2)
|
||||
}
|
||||
|
||||
function formatPercent(value: number | null | undefined): string {
|
||||
const rounded = Math.round(Number(value || 0) * 100) / 100
|
||||
return `${Number.isInteger(rounded) ? rounded.toString() : rounded.toString()}%`
|
||||
}
|
||||
|
||||
function formatDateTime(value: string | null | undefined): string {
|
||||
return value ? formatDisplayDateTime(value) : '-'
|
||||
}
|
||||
|
||||
async function openUserOverview(userId: number) {
|
||||
if (!userId) return
|
||||
overviewDialog.value = true
|
||||
overviewLoading.value = true
|
||||
selectedOverview.value = null
|
||||
try {
|
||||
selectedOverview.value = await affiliatesAPI.getUserOverview(userId)
|
||||
} catch (error) {
|
||||
overviewDialog.value = false
|
||||
appStore.showError(extractI18nErrorMessage(error, t, 'admin.affiliates.errors', t('common.error')))
|
||||
} finally {
|
||||
overviewLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const UserCell = defineComponent({
|
||||
props: {
|
||||
id: { type: Number, required: true },
|
||||
email: { type: String, default: '' },
|
||||
username: { type: String, default: '' },
|
||||
clickable: { type: Boolean, default: false },
|
||||
},
|
||||
emits: ['open'],
|
||||
setup(cellProps, { emit }) {
|
||||
return () => h('div', { class: 'space-y-0.5' }, [
|
||||
h('div', { class: 'font-mono text-sm text-gray-900 dark:text-white' }, `#${cellProps.id}`),
|
||||
h(cellProps.clickable ? 'button' : 'div', {
|
||||
class: cellProps.clickable
|
||||
? 'max-w-56 truncate text-left text-sm font-medium text-primary-600 hover:text-primary-700 hover:underline dark:text-primary-400 dark:hover:text-primary-300'
|
||||
: 'max-w-56 truncate text-sm text-gray-700 dark:text-gray-300',
|
||||
type: cellProps.clickable ? 'button' : undefined,
|
||||
onClick: cellProps.clickable ? () => emit('open', cellProps.id) : undefined,
|
||||
}, cellProps.email || '-'),
|
||||
h('div', { class: 'max-w-56 truncate text-sm text-gray-500 dark:text-dark-400' }, cellProps.username || '-'),
|
||||
])
|
||||
},
|
||||
})
|
||||
|
||||
const AmountText = defineComponent({
|
||||
props: {
|
||||
value: { type: Number, default: 0 },
|
||||
strong: { type: Boolean, default: false },
|
||||
},
|
||||
setup(amountProps) {
|
||||
return () => h('span', {
|
||||
class: amountProps.strong
|
||||
? 'text-sm font-semibold text-emerald-600 dark:text-emerald-400'
|
||||
: 'text-sm text-gray-900 dark:text-white',
|
||||
}, `$${formatAmount(amountProps.value)}`)
|
||||
},
|
||||
})
|
||||
|
||||
const NullableAmountText = defineComponent({
|
||||
props: {
|
||||
value: { type: Number as PropType<number | null | undefined>, default: null },
|
||||
},
|
||||
setup(amountProps) {
|
||||
return () => {
|
||||
const value = amountProps.value
|
||||
if (value === null || value === undefined) {
|
||||
return h('span', { class: 'text-sm text-gray-400 dark:text-dark-500' }, '-')
|
||||
}
|
||||
return h(AmountText, { value })
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
const OverviewStat = defineComponent({
|
||||
props: {
|
||||
label: { type: String, required: true },
|
||||
value: { type: String, required: true },
|
||||
mono: { type: Boolean, default: false },
|
||||
},
|
||||
setup(statProps) {
|
||||
return () => h('div', { class: 'rounded-lg border border-gray-100 bg-white p-3 dark:border-dark-700 dark:bg-dark-900' }, [
|
||||
h('div', { class: 'text-sm text-gray-500 dark:text-dark-400' }, statProps.label),
|
||||
h('div', {
|
||||
class: statProps.mono
|
||||
? 'mt-1 font-mono text-base font-semibold text-gray-900 dark:text-white'
|
||||
: 'mt-1 text-base font-semibold text-gray-900 dark:text-white',
|
||||
}, statProps.value),
|
||||
])
|
||||
},
|
||||
})
|
||||
|
||||
// Kick off the initial fetch once the component is mounted.
onMounted(() => {
  void loadRecords()
})
|
||||
</script>
|
||||
@@ -0,0 +1,7 @@
|
||||
<template>
  <!-- Thin wrapper page: renders the shared affiliate records table in "transfers" mode. -->
  <AdminAffiliateRecordsTable type="transfers" />
</template>

<script setup lang="ts">
// Shared table component parameterized by record type ('transfers' here).
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
</script>
|
||||
Reference in New Issue
Block a user