Mirror of https://gitee.com/wanwujie/sub2api, synced 2026-05-05 05:30:44 +08:00
Mirroring Claude's BetaPolicy fast-mode filter, this adds a three-state
pass / filter / block policy for the OpenAI upstream service_tier field
(priority / flex, including normalization of the client alias "fast" →
"priority"), covering every OpenAI entry point plus the admin configuration
entry.
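
The three actions, as a sketch in Go constants (hypothetical identifiers; the
repository's actual names may differ):

    package service

    type openAIFastPolicyAction string

    const (
    	actionPass   openAIFastPolicyAction = "pass"   // forward service_tier (after alias normalization)
    	actionFilter openAIFastPolicyAction = "filter" // strip service_tier, forward the rest
    	actionBlock  openAIFastPolicyAction = "block"  // reject: HTTP 403 / WS close 1008
    )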
Backend core
- Add the SettingKeyOpenAIFastPolicySettings, OpenAIFastPolicyRule, and
  OpenAIFastPolicySettings config models; each rule carries the
  service_tier × action × scope × model whitelist × fallback action
  dimensions.
- SettingService.Get/SetOpenAIFastPolicySettings; when the setting is
  missing, a built-in default policy is returned (priority is filtered for
  all models, whitelist empty, fallback=pass). Rationale: service_tier=fast
  is a user-level switch orthogonal to the model field, so a default pinned
  to specific model slugs would leave a "gpt-4 + fast passes priority
  through to the upstream" bypass path. JSON parse failures no longer fall
  back silently; slog.Warn logs the dirty data so operators can locate it.
- service_tier normalization (trim + ToLower + fast→priority + a
  priority/flex allowlist) and policy evaluation (evaluateOpenAIFastPolicy)
  are the single source of truth, shared by HTTP and WS. The pure function
  evaluateOpenAIFastPolicyWithSettings is factored out and paired with a
  ctx-bound settings snapshot (withOpenAIFastPolicyContext /
  openAIFastPolicySettingsFromContext), so the long-lived WS entry
  prefetches once and every frame reuses the snapshot instead of hitting
  settingService per frame. A sketch of the normalization follows this
  section.
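
The normalization step described above might look roughly like the following
sketch. The function name and return shape are assumptions for illustration;
only the trim / lower-case / alias / allowlist behavior comes from the commit
text.

    package service

    import "strings"

    // normalizeServiceTier is a hypothetical sketch: trim and lower-case the
    // raw value, map the client alias "fast" to "priority", and accept only
    // the priority/flex allowlist. Anything else is not governed by the
    // fast policy.
    func normalizeServiceTier(raw string) (tier string, governed bool) {
    	tier = strings.ToLower(strings.TrimSpace(raw))
    	if tier == "fast" {
    		tier = "priority"
    	}
    	switch tier {
    	case "priority", "flex":
    		return tier, true
    	}
    	return "", false
    }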
HTTP entry points (4)
- Chat Completions, Anthropic compatibility (Messages, including the
  secondary BetaFastMode→priority hit), native Responses, and Passthrough
  Responses are all wired through applyOpenAIFastPolicyToBody; filter
  deletes the top-level service_tier via sjson, block returns a 403
  forbidden_error JSON (a sketch of the filter action follows this section).
- All four entries evaluate against the upstream-view model (the slug after
  GetMappedModel + normalizeOpenAIModelForUpstream + Codex OAuth
  normalization), so chat/messages/native /responses/passthrough cannot
  diverge on whitelist hits just because their model dimensions differ.
- On the pass path the client alias "fast" is also normalized to "priority"
  and written back into the body; otherwise the native /responses and
  passthrough entries would forward "fast" verbatim and the upstream would
  400/reject it (the chat-completions entry's
  normalizeResponsesBodyServiceTier already behaved this way).
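
A minimal sketch of the filter action on an HTTP body, assuming the
tidwall/gjson + tidwall/sjson pairing the commit text implies; the real
applyOpenAIFastPolicyToBody also evaluates the rule set and can block:

    package service

    import (
    	"github.com/tidwall/gjson"
    	"github.com/tidwall/sjson"
    )

    // filterServiceTier is a hypothetical helper: drop the top-level
    // service_tier field and leave the rest of the body untouched.
    func filterServiceTier(body []byte) ([]byte, error) {
    	if !gjson.GetBytes(body, "service_tier").Exists() {
    		return body, nil // nothing to strip; idempotent by construction
    	}
    	return sjson.DeleteBytes(body, "service_tier")
    }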
WebSocket entry points
- Add applyOpenAIFastPolicyToWSResponseCreate: strictly matches
  type="response.create" and only touches the top-level service_tier;
  filter removes the field via sjson, block returns a typed
  *OpenAIFastBlockedError.
- The ingress path invokes it inside parseClientPayload; on a block hit it
  first writes a Realtime-style error event and then returns
  OpenAIWSClientCloseError (StatusPolicyViolation=1008), relying on the
  synchronous flush of the underlying WebSocket Conn.Write to guarantee the
  error arrives before the close.
- The passthrough path applies the policy to firstClientMessage before
  RunEntry, and wraps ReadFrame with openAIWSPolicyEnforcingFrameConn so the
  policy runs on every client→upstream frame; later frames without a model
  field fall back to capturedSessionModel. The filter closure also watches
  session.update / session.created frames and refreshes capturedSessionModel
  from session.model, closing the mid-session bypass of "first frame
  model=gpt-4o (pass) → session.update switches to gpt-5.5 → a model-less
  response.create falls back to gpt-4o" (a sketch of this wrapper follows
  this section).
- Passthrough billing: requestServiceTier is extracted from
  firstClientMessage only after the policy filter, so on a filter hit
  OpenAIForwardResult.ServiceTier reports nil (default tier), consistent
  with the HTTP entries (reqBody comes from the post-filter map) and WS
  ingress (payload comes from the post-filter bytes).
- Error event schema: {event_id: "evt_<32hex>", type: "error",
  error: {type: "forbidden_error", code: "policy_violation", message}},
  compatible with the OpenAI codex client's error event parsing.
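
A sketch of the per-frame enforcement wrapper described above. The field
names, the ReadFrame signature, and the JSON paths are assumptions; the
repository's openAIWSPolicyEnforcingFrameConn wraps its own concrete frame
connection type:

    package service

    import "github.com/tidwall/gjson"

    type frameReader interface {
    	ReadFrame() ([]byte, error)
    }

    // policyEnforcingFrameConn re-checks the fast policy on every
    // client→upstream frame, tracking the session model so that frames
    // without an explicit model cannot dodge a whitelist rule.
    type policyEnforcingFrameConn struct {
    	inner                frameReader
    	capturedSessionModel string // model established by earlier frames
    	apply                func(frame []byte, model string) ([]byte, error)
    }

    func (c *policyEnforcingFrameConn) ReadFrame() ([]byte, error) {
    	frame, err := c.inner.ReadFrame()
    	if err != nil {
    		return nil, err
    	}
    	// session.update / session.created rotate the captured model, so a
    	// mid-session model switch is seen by later policy evaluations.
    	switch gjson.GetBytes(frame, "type").String() {
    	case "session.update", "session.created":
    		if m := gjson.GetBytes(frame, "session.model").String(); m != "" {
    			c.capturedSessionModel = m
    		}
    	}
    	model := gjson.GetBytes(frame, "model").String()
    	if model == "" {
    		model = c.capturedSessionModel // fallback for model-less frames
    	}
    	return c.apply(frame, model)
    }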
Admin / Frontend
- dto.SystemSettings / UpdateSettingsRequest gain an
  openai_fast_policy_settings field (omitempty), wired into bulk GET/PUT
  (see the sketch after this section).
- The Settings page's Gateway tab gains a Fast/Flex Policy form card with
  full configuration of service_tier × action × scope × model whitelist ×
  fallback action.
- Frontend guards: the openaiFastPolicyLoaded flag only allows write-back
  when the GET actually carried the field, so a rollout or an error cannot
  overwrite the default rules with an empty set; the saveSettings write-back
  loop skips this field, leaving it to dedicated refresh logic; error_message
  is sent only when action=block, matching the backend's omitempty behavior.
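
The DTO wiring might look like the following; the json tag comes from the
commit text, while the package, surrounding fields, and the rule type's
internals are assumptions:

    package dto

    // OpenAIFastPolicySettings stands in for the backend policy model; its
    // real fields live in the service package.
    type OpenAIFastPolicySettings struct {
    	Rules []map[string]any `json:"rules"`
    }

    type SystemSettings struct {
    	// ...existing settings fields...
    	OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
    }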
Tests
- HTTP path: openai_fast_policy_test.go covers the default config
  (whitelist=[], priority filtered for all models) / block with a custom
  error / scope differentiation / filter removing the field / block leaving
  the body untouched / block short-circuiting the upstream / Anthropic
  BetaFastMode triggering the OpenAI fast policy.
- WebSocket path: openai_fast_policy_ws_test.go covers
  helper units (filter / fast→priority normalization / flex pass-through /
  block typed error / bytes unchanged without service_tier /
  non-response.create frames untouched / empty-type frames untouched /
  event_id+code field assertions / non-string service_tier tolerance) +
  a pass-path regression for fast alias normalization +
  ingress end-to-end (upstream free of service_tier after filter / on block
  the client receives the error event before close 1008 with zero upstream
  writes) +
  passthrough capturedSessionModel fallback cases (established by the first
  frame under a whitelist policy, model-less frames hitting the fallback,
  and the documented leak when no fallback exists) +
  a passthrough mid-session bypass regression for session.update /
  session.created rotating capturedSessionModel +
  passthrough billing post-filter ServiceTier and idempotent-filter
  regressions.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
613 lines · 20 KiB · Go
package service

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
	"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

// ForwardAsAnthropic accepts an Anthropic Messages request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Anthropic Messages format. This enables Claude Code
// clients to access OpenAI models through the standard /v1/messages endpoint.
func (s *OpenAIGatewayService) ForwardAsAnthropic(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	promptCacheKey string,
	defaultMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()

	// 1. Parse Anthropic request
	var anthropicReq apicompat.AnthropicRequest
	if err := json.Unmarshal(body, &anthropicReq); err != nil {
		return nil, fmt.Errorf("parse anthropic request: %w", err)
	}
	originalModel := anthropicReq.Model
	applyOpenAICompatModelNormalization(&anthropicReq)
	normalizedModel := anthropicReq.Model
	clientStream := anthropicReq.Stream // client's original stream preference

	// 2. Convert Anthropic → Responses
	responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
	if err != nil {
		return nil, fmt.Errorf("convert anthropic to responses: %w", err)
	}

	// Upstream always uses streaming (upstream may not support sync mode).
	// The client's original preference determines the response format.
	responsesReq.Stream = true
	isStream := true

	// 2b. Handle BetaFastMode → service_tier: "priority"
	if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
		responsesReq.ServiceTier = "priority"
	}

	// 3. Model mapping
	billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
	upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
	responsesReq.Model = upstreamModel

	logger.L().Debug("openai messages: model mapping applied",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("normalized_model", normalizedModel),
		zap.String("billing_model", billingModel),
		zap.String("upstream_model", upstreamModel),
		zap.Bool("stream", isStream),
	)

	// 4. Marshal Responses request body, then apply OAuth codex transform
	responsesBody, err := json.Marshal(responsesReq)
	if err != nil {
		return nil, fmt.Errorf("marshal responses request: %w", err)
	}

	if account.Type == AccountTypeOAuth {
		var reqBody map[string]any
		if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
			return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
		}
		codexResult := applyCodexOAuthTransform(reqBody, false, false)
		forcedTemplateText := ""
		if s.cfg != nil {
			forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate
		}
		templateUpstreamModel := upstreamModel
		if codexResult.NormalizedModel != "" {
			templateUpstreamModel = codexResult.NormalizedModel
		}
		existingInstructions, _ := reqBody["instructions"].(string)
		if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{
			ExistingInstructions: strings.TrimSpace(existingInstructions),
			OriginalModel:        originalModel,
			NormalizedModel:      normalizedModel,
			BillingModel:         billingModel,
			UpstreamModel:        templateUpstreamModel,
		}); err != nil {
			return nil, err
		}
		if codexResult.NormalizedModel != "" {
			upstreamModel = codexResult.NormalizedModel
		}
		if codexResult.PromptCacheKey != "" {
			promptCacheKey = codexResult.PromptCacheKey
		} else if promptCacheKey != "" {
			reqBody["prompt_cache_key"] = promptCacheKey
		}
		// OAuth codex transform forces stream=true upstream, so always use
		// the streaming response handler regardless of what the client asked.
		isStream = true
		responsesBody, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("remarshal after codex transform: %w", err)
		}
	}

	// For API key accounts (including OpenAI-compatible upstream gateways),
	// ensure promptCacheKey is also propagated via the request body so that
	// upstreams using the Responses API can derive a stable session identifier
	// from prompt_cache_key. This makes our Anthropic /v1/messages compatibility
	// path behave more like a native Responses client.
	if account.Type == AccountTypeAPIKey {
		if trimmedKey := strings.TrimSpace(promptCacheKey); trimmedKey != "" {
			var reqBody map[string]any
			if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
				return nil, fmt.Errorf("unmarshal for prompt cache key injection: %w", err)
			}
			if existing, ok := reqBody["prompt_cache_key"].(string); !ok || strings.TrimSpace(existing) == "" {
				reqBody["prompt_cache_key"] = trimmedKey
				updated, err := json.Marshal(reqBody)
				if err != nil {
					return nil, fmt.Errorf("remarshal after prompt cache key injection: %w", err)
				}
				responsesBody = updated
			}
		}
	}

	// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
	// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
	// on the body-level service_tier field (priority/flex).
	updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
	if policyErr != nil {
		var blocked *OpenAIFastBlockedError
		if errors.As(policyErr, &blocked) {
			writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
		}
		return nil, policyErr
	}
	responsesBody = updatedBody

	// 5. Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}

	// 6. Build upstream request
	upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}

	// Override session_id with a deterministic UUID derived from the isolated
	// session key, ensuring different API keys produce different upstream sessions.
	if promptCacheKey != "" {
		apiKeyID := getAPIKeyIDFromContext(c)
		upstreamReq.Header.Set("session_id", generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey)))
	}

	// 7. Send request
	proxyURL := ""
	if account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()

	// 8. Handle error response with failover
	if resp.StatusCode >= 400 {
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))

		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			upstreamDetail := ""
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
			}
		}
		// Non-failover error: return Anthropic-formatted error to client
		return s.handleAnthropicErrorResponse(resp, c, account)
	}

	// 9. Handle normal response
	// Upstream is always streaming; choose response format based on client preference.
	var result *OpenAIForwardResult
	var handleErr error
	if clientStream {
		result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
	} else {
		// Client wants JSON: buffer the streaming response and assemble a JSON reply.
		result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
	}

	// Propagate ServiceTier and ReasoningEffort to result for billing
	if handleErr == nil && result != nil {
		if responsesReq.ServiceTier != "" {
			st := responsesReq.ServiceTier
			result.ServiceTier = &st
		}
		if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
			re := responsesReq.Reasoning.Effort
			result.ReasoningEffort = &re
		}
	}

	// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
	if handleErr == nil && account.Type == AccountTypeOAuth {
		if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
			s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
		}
	}

	return result, handleErr
}

// handleAnthropicErrorResponse reads an upstream error and returns it in
// Anthropic error format.
func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
	resp *http.Response,
	c *gin.Context,
	account *Account,
) (*OpenAIForwardResult, error) {
	return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
}

// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
// the upstream streaming response, finds the terminal event (response.completed
// / response.incomplete / response.failed), converts the complete response to
// Anthropic Messages JSON format, and writes it to the client.
// This is used when the client requested stream=false but the upstream is always
// streaming.
func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	billingModel string,
	upstreamModel string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var finalResponse *apicompat.ResponsesResponse
	var usage OpenAIUsage
	acc := apicompat.NewBufferedResponseAccumulator()

	for scanner.Scan() {
		line := scanner.Text()

		if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
			continue
		}
		payload := line[6:]

		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("openai messages buffered: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			continue
		}

		// Accumulate delta content for fallback when terminal output is empty.
		acc.ProcessEvent(&event)

		// Terminal events carry the complete ResponsesResponse with output + usage.
		if (event.Type == "response.completed" || event.Type == "response.done" ||
			event.Type == "response.incomplete" || event.Type == "response.failed") &&
			event.Response != nil {
			finalResponse = event.Response
			if event.Response.Usage != nil {
				usage = OpenAIUsage{
					InputTokens:  event.Response.Usage.InputTokens,
					OutputTokens: event.Response.Usage.OutputTokens,
				}
				if event.Response.Usage.InputTokensDetails != nil {
					usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("openai messages buffered: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	if finalResponse == nil {
		writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
		return nil, fmt.Errorf("upstream stream ended without terminal event")
	}

	// When the terminal event has an empty output array, reconstruct from
	// accumulated delta events so the client receives the full content.
	acc.SupplementResponseOutput(finalResponse)

	anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel)

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, anthropicResp)

	return &OpenAIForwardResult{
		RequestID:     requestID,
		Usage:         usage,
		Model:         originalModel,
		BillingModel:  billingModel,
		UpstreamModel: upstreamModel,
		Stream:        false,
		Duration:      time.Since(startTime),
	}, nil
}

// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
// converts each to Anthropic SSE events, and writes them to the client.
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
// pattern to send Anthropic ping events during periods of upstream silence,
// preventing proxy/client timeout disconnections.
func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	billingModel string,
	upstreamModel string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)

	state := apicompat.NewResponsesEventToAnthropicState()
	state.Model = originalModel
	var usage OpenAIUsage
	var firstTokenMs *int
	firstChunk := true

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// resultWithUsage builds the final result snapshot.
	resultWithUsage := func() *OpenAIForwardResult {
		return &OpenAIForwardResult{
			RequestID:     requestID,
			Usage:         usage,
			Model:         originalModel,
			BillingModel:  billingModel,
			UpstreamModel: upstreamModel,
			Stream:        true,
			Duration:      time.Since(startTime),
			FirstTokenMs:  firstTokenMs,
		}
	}

	// processDataLine handles a single "data: ..." SSE line from upstream.
	// Returns (clientDisconnected bool).
	processDataLine := func(payload string) bool {
		if firstChunk {
			firstChunk = false
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}

		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("openai messages stream: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			return false
		}

		// Extract usage from completion events
		if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
			event.Response != nil && event.Response.Usage != nil {
			usage = OpenAIUsage{
				InputTokens:  event.Response.Usage.InputTokens,
				OutputTokens: event.Response.Usage.OutputTokens,
			}
			if event.Response.Usage.InputTokensDetails != nil {
				usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
			}
		}

		// Convert to Anthropic events
		events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
		for _, evt := range events {
			sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
			if err != nil {
				logger.L().Warn("openai messages stream: failed to marshal event",
					zap.Error(err),
					zap.String("request_id", requestID),
				)
				continue
			}
			if _, err := fmt.Fprint(c.Writer, sse); err != nil {
				logger.L().Info("openai messages stream: client disconnected",
					zap.String("request_id", requestID),
				)
				return true
			}
		}
		if len(events) > 0 {
			c.Writer.Flush()
		}
		return false
	}

	// finalizeStream sends any remaining Anthropic events and returns the result.
	finalizeStream := func() (*OpenAIForwardResult, error) {
		if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
			for _, evt := range finalEvents {
				sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
				if err != nil {
					continue
				}
				fmt.Fprint(c.Writer, sse) //nolint:errcheck
			}
			c.Writer.Flush()
		}
		return resultWithUsage(), nil
	}

	// handleScanErr logs scanner errors if meaningful.
	handleScanErr := func(err error) {
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("openai messages stream: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	// ── Determine keepalive interval ──
	keepaliveInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
	}

	// ── No keepalive: fast synchronous path (no goroutine overhead) ──
	if keepaliveInterval <= 0 {
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
				continue
			}
			if processDataLine(line[6:]) {
				return resultWithUsage(), nil
			}
		}
		handleScanErr(scanner.Err())
		return finalizeStream()
	}

	// ── With keepalive: goroutine + channel + select ──
	type scanEvent struct {
		line string
		err  error
	}
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	go func() {
		defer close(events)
		for scanner.Scan() {
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)

	keepaliveTicker := time.NewTicker(keepaliveInterval)
	defer keepaliveTicker.Stop()
	lastDataAt := time.Now()

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream closed
				return finalizeStream()
			}
			if ev.err != nil {
				handleScanErr(ev.err)
				return finalizeStream()
			}
			lastDataAt = time.Now()
			line := ev.line
			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
				continue
			}
			if processDataLine(line[6:]) {
				return resultWithUsage(), nil
			}

		case <-keepaliveTicker.C:
			if time.Since(lastDataAt) < keepaliveInterval {
				continue
			}
			// Send Anthropic-format ping event
			if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil {
				// Client disconnected
				logger.L().Info("openai messages stream: client disconnected during keepalive",
					zap.String("request_id", requestID),
				)
				return resultWithUsage(), nil
			}
			c.Writer.Flush()
		}
	}
}

// writeAnthropicError writes an error response in Anthropic Messages API format.
func writeAnthropicError(c *gin.Context, statusCode int, errType, message string) {
	c.JSON(statusCode, gin.H{
		"type": "error",
		"error": gin.H{
			"type":    errType,
			"message": message,
		},
	})
}