mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-24 00:24:45 +08:00
Merge remote-tracking branch 'origin/main' into fix/enc_coot
# Conflicts: # backend/internal/service/openai_gateway_service.go
This commit is contained in:
@@ -301,6 +301,7 @@ var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICo
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageBillingRepo UsageBillingRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
@@ -338,6 +339,7 @@ type OpenAIGatewayService struct {
|
||||
func NewOpenAIGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageBillingRepo UsageBillingRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
@@ -355,6 +357,7 @@ func NewOpenAIGatewayService(
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
@@ -2119,7 +2122,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
httpInvalidEncryptedContentRetryTried := false
|
||||
for {
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2326,7 +2331,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2663,6 +2670,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
sawTerminalEvent := false
|
||||
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
@@ -2682,6 +2690,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if trimmedData == "[DONE]" {
|
||||
sawDone = true
|
||||
}
|
||||
if openAIStreamEventIsTerminal(trimmedData) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
@@ -2699,19 +2710,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
||||
if sawTerminalEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
err,
|
||||
ctx.Err(),
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
@@ -2725,12 +2731,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && ctx.Err() == nil {
|
||||
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
||||
logger.FromContext(ctx).With(
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("upstream_request_id", upstreamRequestID),
|
||||
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
@@ -3264,6 +3271,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
||||
errorEventSent := false
|
||||
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||
sawTerminalEvent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent || clientDisconnected {
|
||||
return
|
||||
@@ -3294,22 +3302,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||
}
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
|
||||
if scanErr == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sawTerminalEvent {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
}
|
||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
|
||||
return resultWithUsage(), nil, true
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
||||
}
|
||||
if errors.Is(scanErr, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
|
||||
@@ -3332,6 +3345,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
|
||||
dataBytes := []byte(data)
|
||||
if openAIStreamEventIsTerminal(data) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
||||
@@ -3448,8 +3464,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return resultWithUsage(), nil
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
@@ -3547,11 +3562,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
||||
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
return
|
||||
}
|
||||
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
|
||||
if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
|
||||
// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
|
||||
if len(data) < 72 {
|
||||
return
|
||||
}
|
||||
if gjson.GetBytes(data, "type").String() != "response.completed" {
|
||||
eventType := gjson.GetBytes(data, "type").String()
|
||||
if eventType != "response.completed" && eventType != "response.done" {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4007,14 +4023,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -4080,11 +4097,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: billingModel,
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
@@ -4125,29 +4143,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user