mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
Merge pull request #2066 from alfadb/fix/anthropic-stream-eof-failover
fix(gateway): Anthropic 流式 EOF 失败移交 + SSE error 帧标准化
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
mathrand "math/rand"
|
mathrand "math/rand"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
@@ -6520,6 +6522,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sanitizeStreamError returns a client-visible description of err that never
// contains network addresses.
//
// The default (*net.OpError).Error() concatenates the Source/Addr fields and
// therefore leaks internal IPs/ports as well as the upstream server address
// (for example "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection reset
// by peer"). This function keeps only a recognizable error category; callers
// are expected to write the raw err to their own logs.
//
// A nil err yields "". An error that matches no known category yields
// "upstream connection error".
func sanitizeStreamError(err error) string {
	if err == nil {
		return ""
	}

	// Well-known sentinels first. io.ErrUnexpectedEOF is checked before io.EOF
	// so the more specific description wins.
	switch {
	case errors.Is(err, io.ErrUnexpectedEOF):
		return "unexpected EOF"
	case errors.Is(err, io.EOF):
		return "EOF"
	case errors.Is(err, context.Canceled):
		return "canceled"
	case errors.Is(err, context.DeadlineExceeded):
		return "deadline exceeded"
	case errors.Is(err, syscall.ECONNRESET):
		return "connection reset by peer"
	case errors.Is(err, syscall.ECONNABORTED):
		return "connection aborted"
	case errors.Is(err, syscall.ETIMEDOUT):
		return "connection timed out"
	case errors.Is(err, syscall.EPIPE):
		return "broken pipe"
	case errors.Is(err, syscall.ECONNREFUSED):
		return "connection refused"
	}

	// *net.OpError: expose only the operation name ("read", "write", "dial", ...),
	// never Source/Addr.
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		if opErr.Timeout() {
			if opErr.Op != "" {
				return opErr.Op + " timeout"
			}
			return "i/o timeout"
		}
		if opErr.Op != "" {
			return opErr.Op + " network error"
		}
	}

	// Timeouts that are not wrapped in *net.OpError (os.ErrDeadlineExceeded,
	// *url.Error, TLS errors, ...) still satisfy net.Error; classify them as a
	// timeout instead of falling through to the generic message.
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return "i/o timeout"
	}

	return "upstream connection error"
}
|
||||||
|
|
||||||
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
|
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
|
||||||
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
|
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
|
||||||
func ExtractUpstreamErrorMessage(body []byte) string {
|
func ExtractUpstreamErrorMessage(body []byte) string {
|
||||||
@@ -6957,14 +7002,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
lastDataAt := time.Now()
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)。
|
||||||
|
// 事件格式遵循 Anthropic SSE 标准:{"type":"error","error":{"type":<reason>,"message":<message>}}
|
||||||
|
// 这样 Anthropic SDK / Claude Code 等客户端能按标准 error 类型解析,UI 能显示具体错误文案,
|
||||||
|
// 服务端 ExtractUpstreamErrorMessage 也能从透传的 body 中提取 message。
|
||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason, message string) {
|
||||||
if errorEventSent {
|
if errorEventSent {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errorEventSent = true
|
errorEventSent = true
|
||||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
if message == "" {
|
||||||
|
message = reason
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"type": "error",
|
||||||
|
"error": map[string]string{
|
||||||
|
"type": reason,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
// json.Marshal 不可能在已知 string-only 输入上失败,保守 fallback
|
||||||
|
body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message))
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -7124,10 +7186,32 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
// 客户端未断开,正常的错误处理
|
// 客户端未断开,正常的错误处理
|
||||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||||
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||||
sendErrorEvent("response_too_large")
|
sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize))
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||||
}
|
}
|
||||||
sendErrorEvent("stream_read_error")
|
// 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY):
|
||||||
|
// 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。
|
||||||
|
// 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。
|
||||||
|
// 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 剥离地址,
|
||||||
|
// 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err
|
||||||
|
// 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。
|
||||||
|
disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err)
|
||||||
|
if !c.Writer.Written() {
|
||||||
|
logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err)
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"type": "error",
|
||||||
|
"error": map[string]string{
|
||||||
|
"type": "upstream_disconnected",
|
||||||
|
"message": disconnectMsg,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
ResponseBody: body,
|
||||||
|
RetryableOnSameAccount: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sendErrorEvent("stream_read_error", disconnectMsg)
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||||
}
|
}
|
||||||
line := ev.line
|
line := ev.line
|
||||||
@@ -7186,7 +7270,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
if s.rateLimitService != nil {
|
if s.rateLimitService != nil {
|
||||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||||
}
|
}
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval))
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
case <-keepaliveCh:
|
case <-keepaliveCh:
|
||||||
|
|||||||
@@ -4,9 +4,12 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -218,3 +221,175 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
|
|||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件")
|
require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 上游中途读错误(如 HTTP/2 GOAWAY 触发的 unexpected EOF)发生在向客户端写入任何字节前:
|
||||||
|
// 网关应返回 *UpstreamFailoverError 触发账号 failover/重试,而不是把错误事件直接发给客户端。
|
||||||
|
func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newMinimalGatewayService()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: &streamReadCloser{err: io.ErrUnexpectedEOF},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, result, "失败移交场景下不应返回 streamingResult")
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err)
|
||||||
|
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||||
|
require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试")
|
||||||
|
|
||||||
|
// ResponseBody 必须是 Anthropic 标准 error 格式:
|
||||||
|
// 1) ExtractUpstreamErrorMessage 能正确从 error.message 提取消息(被 handleFailoverExhausted / ops 日志依赖)
|
||||||
|
// 2) error.type 标记为 upstream_disconnected
|
||||||
|
extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody)
|
||||||
|
require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message,否则 ops 日志会丢失诊断信息")
|
||||||
|
require.Contains(t, extractedMsg, "upstream stream disconnected")
|
||||||
|
require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`)
|
||||||
|
require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`)
|
||||||
|
|
||||||
|
// 客户端应收不到任何 stream_read_error 事件,由 handler 层根据 failover 结果再决定
|
||||||
|
require.NotContains(t, rec.Body.String(), "stream_read_error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 上游已经发送过事件(c.Writer 已写过字节)后再发生读错误:
|
||||||
|
// SSE 协议无 resume,网关只能透传 stream_read_error 错误事件给客户端,不能 failover。
|
||||||
|
func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newMinimalGatewayService()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 第一次 Read 返回完整 SSE 事件让网关向 client 写入字节,第二次 Read 返回 EOF
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: &streamReadCloser{
|
||||||
|
payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
|
||||||
|
err: io.ErrUnexpectedEOF,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream read error", "已开始流后应透传普通 stream read error")
|
||||||
|
require.NotNil(t, result, "透传场景下应返回已收集的 streamingResult")
|
||||||
|
|
||||||
|
// 不应被错误地包成 failover error
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.False(t, errors.As(err, &failoverErr), "已经向客户端写过字节时不能再 failover")
|
||||||
|
|
||||||
|
// 客户端必须收到 Anthropic 标准格式的 SSE error 事件,error.type=stream_read_error,
|
||||||
|
// error.message 含具体根因(让 SDK 能解析、UI 能显示具体错误)
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧")
|
||||||
|
require.Contains(t, body, `"type":"error"`, "data 必须含 type:error 顶层字段(Anthropic 标准)")
|
||||||
|
require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error")
|
||||||
|
require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游
|
||||||
|
// 服务器地址。sanitizeStreamError 必须剥离这些信息,避免基础设施拓扑通过
|
||||||
|
// failover ResponseBody 或 SSE error 帧返回给客户端。
|
||||||
|
func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) {
|
||||||
|
src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
|
||||||
|
require.NoError(t, err)
|
||||||
|
dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
raw := &net.OpError{
|
||||||
|
Op: "read",
|
||||||
|
Net: "tcp",
|
||||||
|
Source: src,
|
||||||
|
Addr: dst,
|
||||||
|
Err: syscall.ECONNRESET,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 前置:原始 Error() 确实包含会泄露的字段(避免测试在 Go 行为变化时静默通过)
|
||||||
|
require.Contains(t, raw.Error(), "10.0.0.1")
|
||||||
|
require.Contains(t, raw.Error(), "52.1.2.3")
|
||||||
|
|
||||||
|
got := sanitizeStreamError(raw)
|
||||||
|
require.NotContains(t, got, "10.0.0.1", "不得泄露内部源 IP")
|
||||||
|
require.NotContains(t, got, "54321", "不得泄露源端口")
|
||||||
|
require.NotContains(t, got, "52.1.2.3", "不得泄露上游目标 IP")
|
||||||
|
require.NotContains(t, got, "443", "不得泄露上游端口")
|
||||||
|
require.Equal(t, "connection reset by peer", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeStreamError_KnownErrors(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"},
|
||||||
|
{"EOF", io.EOF, "EOF"},
|
||||||
|
{"context canceled", context.Canceled, "canceled"},
|
||||||
|
{"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"},
|
||||||
|
{"ECONNRESET 直接", syscall.ECONNRESET, "connection reset by peer"},
|
||||||
|
{"EPIPE", syscall.EPIPE, "broken pipe"},
|
||||||
|
{"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"},
|
||||||
|
{"未识别错误兜底", errors.New("weird internal error"), "upstream connection error"},
|
||||||
|
{"nil 返回空串", nil, ""},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tc.want, sanitizeStreamError(tc.err))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// failover ResponseBody 必须用 sanitize 过的消息,避免泄露给客户端 / 写入 ops 日志
|
||||||
|
// 时携带内部地址信息。
|
||||||
|
func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newMinimalGatewayService()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
|
||||||
|
dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
|
||||||
|
netErr := &net.OpError{
|
||||||
|
Op: "read",
|
||||||
|
Net: "tcp",
|
||||||
|
Source: src,
|
||||||
|
Addr: dst,
|
||||||
|
Err: syscall.ECONNRESET,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: &streamReadCloser{err: netErr},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.True(t, errors.As(err, &failoverErr))
|
||||||
|
|
||||||
|
body := string(failoverErr.ResponseBody)
|
||||||
|
require.NotContains(t, body, "10.0.0.1", "failover ResponseBody 不得泄露内部源 IP")
|
||||||
|
require.NotContains(t, body, "54321")
|
||||||
|
require.NotContains(t, body, "52.1.2.3", "failover ResponseBody 不得泄露上游 IP")
|
||||||
|
require.NotContains(t, body, "443")
|
||||||
|
// 仍然包含可诊断的根因
|
||||||
|
require.Contains(t, body, "connection reset by peer")
|
||||||
|
require.Contains(t, body, "upstream stream disconnected")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user