diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go index 69e99dc5..31bba8c5 100644 --- a/backend/internal/pkg/httputil/body.go +++ b/backend/internal/pkg/httputil/body.go @@ -2,8 +2,15 @@ package httputil import ( "bytes" + "compress/gzip" + "compress/zlib" + "errors" + "fmt" "io" "net/http" + "strings" + + "github.com/klauspost/compress/zstd" ) const ( @@ -11,7 +18,9 @@ const ( requestBodyReadMaxInitCap = 1 << 20 ) -// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length. +// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based +// on content length, transparently decoding any Content-Encoding the upstream +// client used to compress the body (zstd, gzip, deflate). func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) { if req == nil || req.Body == nil { return nil, nil @@ -33,5 +42,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) { if _, err := io.Copy(buf, req.Body); err != nil { return nil, err } - return buf.Bytes(), nil + raw := buf.Bytes() + + enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding"))) + if enc == "" || enc == "identity" { + return raw, nil + } + + decoded, err := decompressRequestBody(enc, raw) + if err != nil { + return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err) + } + + req.Header.Del("Content-Encoding") + req.Header.Del("Content-Length") + req.ContentLength = int64(len(decoded)) + + return decoded, nil +} + +func decompressRequestBody(encoding string, raw []byte) ([]byte, error) { + switch encoding { + case "zstd": + dec, err := zstd.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, err + } + defer dec.Close() + return io.ReadAll(dec) + case "gzip", "x-gzip": + gr, err := gzip.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, err + } + defer gr.Close() + return io.ReadAll(gr) + case "deflate": + zr, err := zlib.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, err + } + defer zr.Close() + return io.ReadAll(zr) + default: + return nil, errors.New("unsupported Content-Encoding") + } } diff --git a/backend/internal/pkg/httputil/body_test.go b/backend/internal/pkg/httputil/body_test.go new file mode 100644 index 00000000..ed8355d5 --- /dev/null +++ b/backend/internal/pkg/httputil/body_test.go @@ -0,0 +1,143 @@ +package httputil + +import ( + "bytes" + "compress/gzip" + "compress/zlib" + "net/http" + "strings" + "testing" + + "github.com/klauspost/compress/zstd" +) + +const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}` + +func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request { + t.Helper() + req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + if encoding != "" { + req.Header.Set("Content-Encoding", encoding) + } + req.ContentLength = int64(len(body)) + return req +} + +func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) { + req := newRequestWithBody(t, []byte(samplePayload), "") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } +} + +func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) { + enc, _ := zstd.NewWriter(nil) + compressed := enc.EncodeAll([]byte(samplePayload), nil) + _ = enc.Close() + + req := newRequestWithBody(t, compressed, "zstd") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } + if req.Header.Get("Content-Encoding") != "" { + t.Fatalf("Content-Encoding should be cleared after decoding") + } + if req.ContentLength != int64(len(samplePayload)) { + t.Fatalf("ContentLength not updated: %d", req.ContentLength) + } +} + +func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write([]byte(samplePayload)); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + + req := newRequestWithBody(t, buf.Bytes(), "gzip") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } +} + +func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) { + var buf bytes.Buffer + zw := zlib.NewWriter(&buf) + if _, err := zw.Write([]byte(samplePayload)); err != nil { + t.Fatalf("zlib write: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("zlib close: %v", err) + } + + req := newRequestWithBody(t, buf.Bytes(), "deflate") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } +} + +func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) { + req := newRequestWithBody(t, []byte(samplePayload), "br") + _, err := ReadRequestBodyWithPrealloc(req) + if err == nil { + t.Fatal("expected error for unsupported encoding, got nil") + } + if !strings.Contains(err.Error(), "br") { + t.Fatalf("error should mention encoding, got %v", err) + } +} + +func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) { + req := newRequestWithBody(t, []byte("not actually zstd"), "zstd") + _, err := ReadRequestBodyWithPrealloc(req) + if err == nil { + t.Fatal("expected error for corrupt zstd body, got nil") + } +} + +func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != nil { + t.Fatalf("expected nil body, got %q", got) + } +} + +func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) { + req := newRequestWithBody(t, []byte(samplePayload), "identity") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } +}