Files
sub2api/backend/internal/service/gateway_streaming_test.go
alfadb 4c474616b9 fix(gateway): emit Anthropic-standard SSE error events and failover body
Two follow-ups to PR #2066's failover-wrap fix:

1. Failover ResponseBody (`UpstreamFailoverError.ResponseBody`) was encoded
   as `{"error": "<msg>"}` (string field). `ExtractUpstreamErrorMessage`
   probes for `error.message`, `detail`, or top-level `message` only — so
   `handleFailoverExhausted` and downstream passthrough rules saw an empty
   message, losing the EOF root cause in ops logs. Re-encode as the
   Anthropic standard shape `{"type":"error","error":{"type":"upstream_disconnected","message":"..."}}`.
   (Addresses the inline review comment from copilot-pull-request-reviewer
   on Wei-Shaw/sub2api#2066.)

2. The streaming `event: error` SSE frame for `response_too_large`,
   `stream_read_error`, and `stream_timeout` was non-standard
   (`{"error":"<reason>"}`). Anthropic SDKs (and Claude Code) expect
   `{"type":"error","error":{"type":"...","message":"..."}}` and parse
   `error.type`/`error.message` accordingly. Refactor `sendErrorEvent` to
   take both reason and message, and emit the standard frame so client
   SDKs surface a real diagnostic message instead of a generic stream error.

This does not by itself prevent task interruption on long-stream EOF
(SSE has no resume; client-side retry remains the only complete fix), but
it gives both server-side ops logs and client-side error UIs a meaningful
upstream message so users know the next step is to retry.

Tests updated to assert the new body shape on both branches plus a new
assertion that `ExtractUpstreamErrorMessage` returns a non-empty string.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 20:24:17 +08:00

300 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build unit
package service
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// --- parseSSEUsage 测试 ---
// newMinimalGatewayService returns a GatewayService populated with only the
// fields set here: a gateway config whose StreamDataIntervalTimeout is 0 and
// whose MaxLineSize is defaultMaxLineSize, plus an empty RateLimitService.
// NOTE(review): a zero StreamDataIntervalTimeout presumably disables the
// interval timeout — confirm against the production code.
func newMinimalGatewayService() *GatewayService {
	gatewayCfg := config.GatewayConfig{
		StreamDataIntervalTimeout: 0,
		MaxLineSize:               defaultMaxLineSize,
	}
	svc := &GatewayService{
		cfg:              &config.Config{Gateway: gatewayCfg},
		rateLimitService: &RateLimitService{},
	}
	return svc
}
// TestParseSSEUsage_MessageStart feeds a message_start event to parseSSEUsage
// and checks that the input and cache token counters are populated while
// OutputTokens stays at zero.
func TestParseSSEUsage_MessageStart(t *testing.T) {
	const startEvent = `{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation_input_tokens":50,"cache_read_input_tokens":200}}}`

	gateway := newMinimalGatewayService()
	var got ClaudeUsage
	gateway.parseSSEUsage(startEvent, &got)

	require.Equal(t, 100, got.InputTokens)
	require.Equal(t, 50, got.CacheCreationInputTokens)
	require.Equal(t, 200, got.CacheReadInputTokens)
	require.Equal(t, 0, got.OutputTokens, "message_start 不应设置 output_tokens")
}
func TestParseSSEUsage_MessageDelta(t *testing.T) {
svc := newMinimalGatewayService()
usage := &ClaudeUsage{}
data := `{"type":"message_delta","usage":{"output_tokens":42}}`
svc.parseSSEUsage(data, usage)
require.Equal(t, 42, usage.OutputTokens)
require.Equal(t, 0, usage.InputTokens, "message_delta 的 output_tokens 不应影响已有的 input_tokens")
}
// TestParseSSEUsage_DeltaDoesNotOverwriteStartValues parses a message_start
// followed by a message_delta into the same usage value and checks that the
// delta, which carries no input_tokens, leaves the earlier InputTokens intact.
func TestParseSSEUsage_DeltaDoesNotOverwriteStartValues(t *testing.T) {
	const (
		startEvent = `{"type":"message_start","message":{"usage":{"input_tokens":100}}}`
		deltaEvent = `{"type":"message_delta","usage":{"output_tokens":50}}`
	)
	gateway := newMinimalGatewayService()
	var acc ClaudeUsage

	// Step 1: handle message_start.
	gateway.parseSSEUsage(startEvent, &acc)
	require.Equal(t, 100, acc.InputTokens)

	// Step 2: handle message_delta (output_tokens > 0, input_tokens = 0).
	gateway.parseSSEUsage(deltaEvent, &acc)
	require.Equal(t, 100, acc.InputTokens, "delta 中 input_tokens=0 不应覆盖 start 中的值")
	require.Equal(t, 50, acc.OutputTokens)
}
func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) {
svc := newMinimalGatewayService()
usage := &ClaudeUsage{}
// GLM 等 API 会在 delta 中包含所有 usage 信息
svc.parseSSEUsage(`{"type":"message_delta","usage":{"input_tokens":200,"output_tokens":100,"cache_creation_input_tokens":30,"cache_read_input_tokens":60}}`, usage)
require.Equal(t, 200, usage.InputTokens)
require.Equal(t, 100, usage.OutputTokens)
require.Equal(t, 30, usage.CacheCreationInputTokens)
require.Equal(t, 60, usage.CacheReadInputTokens)
}
func TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) {
svc := newMinimalGatewayService()
usage := &ClaudeUsage{}
// 先在 message_start 中写入非零 5m/1h 明细
svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, usage)
require.Equal(t, 30, usage.CacheCreation5mTokens)
require.Equal(t, 70, usage.CacheCreation1hTokens)
// 后续 delta 带默认 0不应覆盖已有非零值
svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`, usage)
require.Equal(t, 30, usage.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细")
require.Equal(t, 70, usage.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细")
require.Equal(t, 12, usage.OutputTokens)
}
// TestParseSSEUsage_InvalidJSON passes a non-JSON payload to parseSSEUsage and
// checks that the call returns with the token counters still at zero.
func TestParseSSEUsage_InvalidJSON(t *testing.T) {
	gateway := newMinimalGatewayService()
	var counters ClaudeUsage

	// Malformed JSON must not cause a panic.
	gateway.parseSSEUsage("not json", &counters)

	require.Equal(t, 0, counters.InputTokens)
	require.Equal(t, 0, counters.OutputTokens)
}
// TestParseSSEUsage_UnknownType passes an event whose type is neither
// message_start nor message_delta and checks that the counters stay at zero.
func TestParseSSEUsage_UnknownType(t *testing.T) {
	const otherEvent = `{"type":"content_block_delta","delta":{"text":"hello"}}`

	gateway := newMinimalGatewayService()
	var counters ClaudeUsage
	// The event type here is not one of message_start / message_delta.
	gateway.parseSSEUsage(otherEvent, &counters)

	require.Equal(t, 0, counters.InputTokens)
	require.Equal(t, 0, counters.OutputTokens)
}
// TestParseSSEUsage_EmptyString passes an empty payload to parseSSEUsage and
// checks that InputTokens stays at zero.
func TestParseSSEUsage_EmptyString(t *testing.T) {
	gateway := newMinimalGatewayService()
	var counters ClaudeUsage

	gateway.parseSSEUsage("", &counters)

	require.Equal(t, 0, counters.InputTokens)
}
// TestParseSSEUsage_DoneEvent passes the literal "[DONE]" payload to
// parseSSEUsage and checks that InputTokens stays at zero.
func TestParseSSEUsage_DoneEvent(t *testing.T) {
	const doneMarker = "[DONE]"

	gateway := newMinimalGatewayService()
	var counters ClaudeUsage
	// The [DONE] event must leave usage untouched.
	gateway.parseSSEUsage(doneMarker, &counters)

	require.Equal(t, 0, counters.InputTokens)
}
// --- 流式响应端到端测试 ---
// TestHandleStreamingResponse_CacheTokens streams a message_start, a
// message_delta and a [DONE] frame through an io.Pipe-backed response body and
// checks that handleStreamingResponse succeeds and reports the input, output
// and cache token counts carried by those frames.
func TestHandleStreamingResponse_CacheTokens(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gateway := newMinimalGatewayService()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: bodyReader}

	// Frames are written one per Write call, in this order.
	frames := []string{
		"data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":20,\"cache_read_input_tokens\":30}}}\n\n",
		"data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":15}}\n\n",
		"data: [DONE]\n\n",
	}
	go func() {
		// Closing the writer ends the stream for the reader side.
		defer func() { _ = bodyWriter.Close() }()
		for _, frame := range frames {
			_, _ = bodyWriter.Write([]byte(frame))
		}
	}()

	res, streamErr := gateway.handleStreamingResponse(context.Background(), upstream, ginCtx, &Account{ID: 1}, time.Now(), "model", "model", false)
	// Closing the reader unblocks the writer goroutine if it is still writing.
	_ = bodyReader.Close()

	require.NoError(t, streamErr)
	require.NotNil(t, res)
	require.NotNil(t, res.usage)
	require.Equal(t, 10, res.usage.InputTokens)
	require.Equal(t, 15, res.usage.OutputTokens)
	require.Equal(t, 20, res.usage.CacheCreationInputTokens)
	require.Equal(t, 30, res.usage.CacheReadInputTokens)
}
// TestHandleStreamingResponse_EmptyStream closes the upstream body without
// sending any event and checks that handleStreamingResponse returns an error
// mentioning "missing terminal event" together with a non-nil result.
func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gateway := newMinimalGatewayService()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: bodyReader}

	// Close the write side right away; no event is ever sent.
	go func() { _ = bodyWriter.Close() }()

	res, streamErr := gateway.handleStreamingResponse(context.Background(), upstream, ginCtx, &Account{ID: 1}, time.Now(), "model", "model", false)
	_ = bodyReader.Close()

	require.Error(t, streamErr)
	require.Contains(t, streamErr.Error(), "missing terminal event")
	require.NotNil(t, res)
}
// TestHandleStreamingResponse_SpecialCharactersInJSON streams a
// content_block_delta whose text contains escaped quotes, an escaped newline
// and non-ASCII characters, followed by usage events and [DONE]. It checks
// that usage is still parsed and that the recorded response body contains the
// forwarded content_block_delta event.
func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gateway := newMinimalGatewayService()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: bodyReader}

	frames := []string{
		// content_block_delta with special characters: quotes, newline, Unicode.
		"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello \\\"world\\\"\\n你好\"}}\n\n",
		"data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n",
		"data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n",
		"data: [DONE]\n\n",
	}
	go func() {
		defer func() { _ = bodyWriter.Close() }()
		for _, frame := range frames {
			_, _ = bodyWriter.Write([]byte(frame))
		}
	}()

	res, streamErr := gateway.handleStreamingResponse(context.Background(), upstream, ginCtx, &Account{ID: 1}, time.Now(), "model", "model", false)
	_ = bodyReader.Close()

	require.NoError(t, streamErr)
	require.NotNil(t, res)
	require.NotNil(t, res.usage)
	require.Equal(t, 5, res.usage.InputTokens)
	require.Equal(t, 3, res.usage.OutputTokens)

	// The data forwarded to the client must show up in the recorded response.
	forwarded := recorder.Body.String()
	require.Contains(t, forwarded, "content_block_delta", "响应应包含转发的 SSE 事件")
}
// TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover
// covers an upstream read error (for example the unexpected EOF triggered by
// an HTTP/2 GOAWAY) that happens before any byte has been written to the
// client. The gateway should return *UpstreamFailoverError so the account
// failover/retry path runs, rather than sending an error event straight to
// the client.
func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gateway := newMinimalGatewayService()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	// The body carries no payload, only the read error.
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       &streamReadCloser{err: io.ErrUnexpectedEOF},
	}

	res, streamErr := gateway.handleStreamingResponse(context.Background(), upstream, ginCtx, &Account{ID: 1}, time.Now(), "model", "model", false)
	require.Error(t, streamErr)
	require.Nil(t, res, "失败移交场景下不应返回 streamingResult")

	var failover *UpstreamFailoverError
	isFailover := errors.As(streamErr, &failover)
	require.True(t, isFailover, "未输出过字节时 stream read error 必须包成 UpstreamFailoverError期望: %v", streamErr)
	require.Equal(t, http.StatusBadGateway, failover.StatusCode)
	require.True(t, failover.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试")

	// ResponseBody must use the Anthropic standard error format:
	//  1) ExtractUpstreamErrorMessage must be able to pull the message out of
	//     error.message (handleFailoverExhausted and ops logs rely on it);
	//  2) error.type must be marked as upstream_disconnected.
	message := ExtractUpstreamErrorMessage(failover.ResponseBody)
	require.NotEmpty(t, message, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message否则 ops 日志会丢失诊断信息")
	require.Contains(t, message, "upstream stream disconnected")

	rawBody := string(failover.ResponseBody)
	require.Contains(t, rawBody, `"type":"error"`)
	require.Contains(t, rawBody, `"upstream_disconnected"`)

	// The client must not receive any stream_read_error event; the handler
	// layer decides what to send based on the failover outcome.
	require.NotContains(t, recorder.Body.String(), "stream_read_error")
}
// TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough covers
// a read error that happens after the upstream has already sent an event
// (c.Writer has already written bytes). SSE has no resume, so the gateway can
// only pass a stream_read_error error event through to the client and must
// not fail over.
func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gateway := newMinimalGatewayService()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	// Intended sequence: the first Read yields one complete SSE event so the
	// gateway writes bytes to the client, and the second Read yields the EOF
	// error. NOTE(review): streamReadCloser is defined outside this view —
	// confirm it behaves this way.
	failingBody := &streamReadCloser{
		payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
		err:     io.ErrUnexpectedEOF,
	}
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       failingBody,
	}

	res, streamErr := gateway.handleStreamingResponse(context.Background(), upstream, ginCtx, &Account{ID: 1}, time.Now(), "model", "model", false)
	require.Error(t, streamErr)
	require.Contains(t, streamErr.Error(), "stream read error", "已开始流后应透传普通 stream read error")
	require.NotNil(t, res, "透传场景下应返回已收集的 streamingResult")

	// The error must not be wrapped as a failover error by mistake.
	var failover *UpstreamFailoverError
	isFailover := errors.As(streamErr, &failover)
	require.False(t, isFailover, "已经向客户端写过字节时不能再 failover")

	// The client must receive an SSE error event in the Anthropic standard
	// format: error.type=stream_read_error and an error.message that carries
	// the concrete root cause, so SDKs can parse it and UIs can display it.
	sent := recorder.Body.String()
	require.Contains(t, sent, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧")
	require.Contains(t, sent, `"type":"error"`, "data 必须含 type:error 顶层字段Anthropic 标准)")
	require.Contains(t, sent, `"stream_read_error"`, "error.type 必须为 stream_read_error")
	require.Contains(t, sent, "upstream stream disconnected", "error.message 必须包含具体根因Claude Code 等客户端才能显示有效错误文案")
}