mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-19 06:14:45 +08:00
Merge remote-tracking branch 'origin/main' into feat/billing-ledger-decouple-usage-log-20260312
This commit is contained in:
@@ -52,6 +52,8 @@ const (
|
||||
openAIWSRetryJitterRatioDefault = 0.2
|
||||
openAICompactSessionSeedKey = "openai_compact_session_seed"
|
||||
codexCLIVersion = "0.104.0"
|
||||
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
|
||||
openAICodexSnapshotPersistMinInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
// OpenAI allowed headers whitelist (for non-passthrough).
|
||||
@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct {
|
||||
nonRetryableFastFallback atomic.Int64
|
||||
}
|
||||
|
||||
type accountWriteThrottle struct {
|
||||
minInterval time.Duration
|
||||
mu sync.Mutex
|
||||
lastByID map[int64]time.Time
|
||||
}
|
||||
|
||||
func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle {
|
||||
return &accountWriteThrottle{
|
||||
minInterval: minInterval,
|
||||
lastByID: make(map[int64]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
|
||||
if t == nil || id <= 0 || t.minInterval <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval {
|
||||
return false
|
||||
}
|
||||
t.lastByID[id] = now
|
||||
|
||||
if len(t.lastByID) > 4096 {
|
||||
cutoff := now.Add(-4 * t.minInterval)
|
||||
for accountID, writtenAt := range t.lastByID {
|
||||
if writtenAt.Before(cutoff) {
|
||||
delete(t.lastByID, accountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
|
||||
|
||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
@@ -290,6 +332,7 @@ type OpenAIGatewayService struct {
|
||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
codexSnapshotThrottle *accountWriteThrottle
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
@@ -332,17 +375,25 @@ func NewOpenAIGatewayService(
|
||||
nil,
|
||||
"service.openai_gateway",
|
||||
),
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
svc.logOpenAIWSModeBootstrap()
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
||||
if s != nil && s.codexSnapshotThrottle != nil {
|
||||
return s.codexSnapshotThrottle
|
||||
}
|
||||
return defaultOpenAICodexSnapshotPersistThrottle
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
||||
return &billingDeps{
|
||||
accountRepo: s.accountRepo,
|
||||
@@ -1719,6 +1770,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
bodyModified = true
|
||||
markPatchSet("model", normalizedModel)
|
||||
}
|
||||
|
||||
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
|
||||
// 确保高版本模型向低版本模型映射不报错
|
||||
if !SupportsVerbosity(normalizedModel) {
|
||||
if text, ok := reqBody["text"].(map[string]any); ok {
|
||||
delete(text, "verbosity")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
||||
@@ -2954,6 +3013,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
||||
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// compatErrorWriter is the signature for format-specific error writers used by
|
||||
// the compat paths (Chat Completions and Anthropic Messages).
|
||||
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
|
||||
|
||||
// handleCompatErrorResponse is the shared non-failover error handler for the
|
||||
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
|
||||
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
|
||||
// tracking, secondary failover) but delegates the final error write to the
|
||||
// format-specific writer function.
|
||||
func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
writeError compatErrorWriter,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
if upstreamMsg == "" {
|
||||
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(body), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
|
||||
// Apply error passthrough rules
|
||||
if status, errType, errMsg, matched := applyErrorPassthroughRule(
|
||||
c, account.Platform, resp.StatusCode, body,
|
||||
http.StatusBadGateway, "api_error", "Upstream request failed",
|
||||
); matched {
|
||||
writeError(c, status, errType, errMsg)
|
||||
if upstreamMsg == "" {
|
||||
upstreamMsg = errMsg
|
||||
}
|
||||
if upstreamMsg == "" {
|
||||
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// Check custom error codes — if the account does not handle this status,
|
||||
// return a generic error without exposing upstream details.
|
||||
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
|
||||
if upstreamMsg == "" {
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// Track rate limits and decide whether to trigger secondary failover.
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(
|
||||
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||
)
|
||||
}
|
||||
kind := "http_error"
|
||||
if shouldDisable {
|
||||
kind = "failover"
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: kind,
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: body,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
// Map status code to error type and write response
|
||||
errType := "api_error"
|
||||
switch {
|
||||
case resp.StatusCode == 400:
|
||||
errType = "invalid_request_error"
|
||||
case resp.StatusCode == 404:
|
||||
errType = "not_found_error"
|
||||
case resp.StatusCode == 429:
|
||||
errType = "rate_limit_error"
|
||||
case resp.StatusCode >= 500:
|
||||
errType = "api_error"
|
||||
}
|
||||
|
||||
writeError(c, resp.StatusCode, errType, upstreamMsg)
|
||||
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// openaiStreamingResult streaming response result
|
||||
type openaiStreamingResult struct {
|
||||
usage *OpenAIUsage
|
||||
@@ -4071,11 +4244,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
if len(updates) == 0 && resetAt == nil {
|
||||
return
|
||||
}
|
||||
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
|
||||
if !shouldPersistUpdates && resetAt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if len(updates) > 0 {
|
||||
if shouldPersistUpdates {
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}
|
||||
if resetAt != nil {
|
||||
|
||||
Reference in New Issue
Block a user