From 9f8cf0f77be56e50df96e158c321844e0bf1b5b9 Mon Sep 17 00:00:00 2001 From: Arpad Ryszka Date: Fri, 3 Apr 2026 17:50:20 +0200 Subject: [PATCH] implement reader from for the buffered writer --- lib.go | 12 +++++++++ lib_test.go | 8 ++++++ writer.go | 40 +++++++++++++++++++++++++++++ writer_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+) diff --git a/lib.go b/lib.go index a33849a..6f92171 100644 --- a/lib.go +++ b/lib.go @@ -256,6 +256,18 @@ func (w Writer) Write(p []byte) (int, error) { return w.writer.write(p) } +// ReadFrom implements the io.ReaderFrom interface. It copies all the data from the provided reader to the +// underlying writer using a buffer from the pool. +// +// It flushes the buffer at the end of the input stream. +func (w Writer) ReadFrom(r io.Reader) (int64, error) { + if w.writer == nil { + return 0, errors.New("unitialized writer") + } + + return w.writer.readFrom(r) +} + // Flush forces the writer to flush the buffered content to the underlying writer. After flushed, the writer // still accepts further writes. func (w Writer) Flush() error { diff --git a/lib_test.go b/lib_test.go index b537a80..50df273 100644 --- a/lib_test.go +++ b/lib_test.go @@ -320,6 +320,14 @@ func TestLib(t *testing.T) { } }) + t.Run("read from", func(t *testing.T) { + var w buffer.Writer + r := bytes.NewBuffer([]byte{1, 2, 3}) + if n, err := w.ReadFrom(r); n != 0 || err == nil { + t.Fatal(n, err) + } + }) + t.Run("flush", func(t *testing.T) { var w buffer.Writer if err := w.Flush(); err != nil { diff --git a/writer.go b/writer.go index 9d00848..fbe7b29 100644 --- a/writer.go +++ b/writer.go @@ -47,6 +47,46 @@ func (w *writer) write(p []byte) (int, error) { } } +func (w *writer) readFrom(r io.Reader) (int64, error) { + var n int64 + for { + if errors.Is(w.err, io.EOF) { + err := w.err + w.err = nil + w.flush() + if w.err != nil { + return n, w.err + } + + w.err = err + return n, nil + } + + if w.err != nil { + return n, w.err + } + + if len(w.buffer) == 0 { + w.buffer, w.err = w.options.Pool.Get() + if len(w.buffer) == 0 && w.err == nil { + w.err = ErrZeroAllocation + } + + continue + } + + if w.offset+w.len == len(w.buffer) { + w.flush() + continue + } + + var ni int + ni, w.err = r.Read(w.buffer[w.offset+w.len:]) + w.len += ni + n += int64(ni) + } +} + func (w *writer) flush() error { var zeroWrite bool for { diff --git a/writer_test.go b/writer_test.go index f51a5c9..a56022a 100644 --- a/writer_test.go +++ b/writer_test.go @@ -412,4 +412,74 @@ func TestWriter(t *testing.T) { t.Fatal(n, err) } }) + + t.Run("read from", func(t *testing.T) { + t.Run("read out", func(t *testing.T) { + w := &writer{} + r := &gen{max: 1 << 12} + p := buffer.NoPool(1 << 9) + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + n, err := b.ReadFrom(r) + if n != 1<<12 || err != nil { + t.Fatal(n, err) + } + }) + + t.Run("read out final flush fail", func(t *testing.T) { + w := &writer{errAfter: []int{1<<12 - 1<<9}} + r := &gen{ + max: 1 << 12, + fastErr: true, + } + + p := buffer.NoPool(1 << 9) + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + n, err := b.ReadFrom(r) + if n != 1<<12 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + }) + + t.Run("read error", func(t *testing.T) { + w := &writer{} + r := &gen{ + max: 1 << 12, + errAfter: []int{1 << 11}, + } + + p := buffer.NoPool(1 << 9) + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + n, err := b.ReadFrom(r) + if n != 1<<11 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + }) + + t.Run("write error", func(t *testing.T) { + w := &writer{errAfter: []int{1 << 11}} + r := &gen{max: 1 << 12} + p := buffer.NoPool(1 << 9) + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + n, err := b.ReadFrom(r) + if n != 1<<11+1<<9 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + }) + + t.Run("zero allocation", func(t *testing.T) { + w := &writer{} + r := &gen{max: 1 << 12} + p := &fakePool{} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + n, err := b.ReadFrom(r) + if n != 0 || !errors.Is(err, buffer.ErrZeroAllocation) { + t.Fatal(n, err) + } + }) + }) }