From e2a24dda248b3f179c1890e6b52ac5cfaa9cce5a Mon Sep 17 00:00:00 2001 From: Arpad Ryszka Date: Fri, 3 Apr 2026 20:56:24 +0200 Subject: [PATCH] improve flushing and closing --- lib_test.go | 126 +++++++++++++++++++++++++++++++++++++++++++++++++ writer.go | 54 +++++++++++---------- writer_test.go | 18 ++++++- 3 files changed, 171 insertions(+), 27 deletions(-) diff --git a/lib_test.go b/lib_test.go index 50df273..fb6daef 100644 --- a/lib_test.go +++ b/lib_test.go @@ -359,6 +359,132 @@ func TestLib(t *testing.T) { } }) }) + + t.Run("closing", func(t *testing.T) { + t.Run("double closing reader", func(t *testing.T) { + g := &gen{max: 1 << 12} + p := &fakePool{allocSize: 1 << 9} + r := buffer.BufferedReader(g, buffer.Options{Pool: p}) + b := bytes.NewBuffer(nil) + if n, err := io.Copy(b, r); n != 1<<12 || err != nil { + t.Fatal(n, err) + } + + r.Close() + r.Close() + if p.alloc != 1 && p.free != 1 { + t.Fatal(p.alloc, p.free) + } + }) + + t.Run("double closing content reader", func(t *testing.T) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { + var n int64 + for i := 0; i < 3; i++ { + ni, err := w.Write([]byte("123456789")[i*3 : i*3+3]) + n += int64(ni) + if err != nil { + return n, err + } + } + + return n, nil + }) + + p := &fakePool{allocSize: 1 << 9} + r := buffer.BufferedContent(c, buffer.Options{Pool: p}) + b := bytes.NewBuffer(nil) + if n, err := io.Copy(b, r); n != 9 || err != nil { + t.Fatal(n, err) + } + + r.Close() + r.Close() + if p.alloc != 1 && p.free != 1 { + t.Fatal(p.alloc, p.free) + } + }) + + t.Run("double closing content reader before eof", func(t *testing.T) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { + var n int64 + for i := 0; i < 3; i++ { + ni, err := w.Write([]byte("123456789")[i*3 : i*3+3]) + n += int64(ni) + if err != nil { + return n, err + } + } + + return n, nil + }) + + p := &fakePool{allocSize: 1 << 9} + r := buffer.BufferedContent(c, buffer.Options{Pool: p}) + b := make([]byte, 3) + if n, err := r.Read(b); n != 3 || err != nil || string(b) != "123" { + t.Fatal(n, err, string(b)) + } + + r.Close() + r.Close() + if p.alloc != 1 && p.free != 1 { + t.Fatal(p.alloc, p.free) + } + }) + + t.Run("double closing writer", func(t *testing.T) { + w := &writer{} + r := &gen{max: 1 << 12} + p := &fakePool{allocSize: 1 << 9} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + n, err := b.ReadFrom(r) + if n != 1<<12 || err != nil { + t.Fatal(n, err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if p.alloc != 1 || p.free != 1 { + t.Fatal(p.alloc, p.free) + } + }) + + t.Run("closing writer after read error in read from", func(t *testing.T) { + w := &writer{} + r := &gen{ + max: 1 << 12, + errAfter: []int{1 << 11}, + } + + p := &fakePool{allocSize: 1 << 9} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + n, err := b.ReadFrom(r) + if n != 1<<11 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if p.alloc != 1 || p.free != 1 { + t.Fatal(p.alloc, p.free) + } + }) + }) } // -- bench diff --git a/writer.go b/writer.go index fbe7b29..7ffb779 100644 --- a/writer.go +++ b/writer.go @@ -50,18 +50,6 @@ func (w *writer) write(p []byte) (int, error) { func (w *writer) readFrom(r io.Reader) (int64, error) { var n int64 for { - if errors.Is(w.err, io.EOF) { - err := w.err - w.err = nil - w.flush() - if w.err != nil { - return n, w.err - } - - w.err = err - return n, nil - } - if w.err != nil { return n, w.err } @@ -80,10 +68,25 @@ func (w *writer) readFrom(r io.Reader) (int64, error) { continue } - var ni int - ni, w.err = r.Read(w.buffer[w.offset+w.len:]) + ni, rerr := r.Read(w.buffer[w.offset+w.len:]) + if ni == 0 && rerr == nil { + ni, rerr = r.Read(w.buffer[w.offset+w.len:]) + } + + if ni == 0 && rerr == nil { + rerr = io.ErrNoProgress + } + w.len += ni n += int64(ni) + if errors.Is(rerr, io.EOF) { + w.flush() + return n, w.err + } + + if rerr != nil { + return n, rerr + } } } @@ -108,25 +111,24 @@ func (w *writer) flush() error { zeroWrite = n == 0 && w.err == nil w.offset += n w.len -= n - if w.err != nil { - w.options.Pool.Put(w.buffer) - w.buffer = nil - } } } func (w *writer) close() error { - if w.err != nil { - return w.err - } - + newErr := w.err == nil w.flush() - if w.err != nil { + if len(w.buffer) > 0 { + w.options.Pool.Put(w.buffer) + w.buffer = nil + } + + if newErr && w.err != nil { return w.err } - w.err = errClosed - w.options.Pool.Put(w.buffer) - w.buffer = nil + if w.err == nil { + w.err = errClosed + } + return nil } diff --git a/writer_test.go b/writer_test.go index a56022a..ab54d85 100644 --- a/writer_test.go +++ b/writer_test.go @@ -358,7 +358,7 @@ func TestWriter(t *testing.T) { t.Fatal(p.alloc, p.free) } - if err := b.Close(); !errors.Is(err, errTest) { + if err := b.Close(); err != nil { t.Fatal(err) } }) @@ -481,5 +481,21 @@ func TestWriter(t *testing.T) { t.Fatal(n, err) } }) + + t.Run("no progress", func(t *testing.T) { + w := &writer{} + r := &gen{ + max: 1 << 15, + nullReadAfter: []int{256, 256}, + } + + p := &fakePool{allocSize: 1 << 9} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + n, err := b.ReadFrom(r) + if n != 512 || !errors.Is(err, io.ErrNoProgress) { + t.Fatal(n, err) + } + }) }) }