From 9d7bed320bcdd5d97e44464dacc43ebc0ff08538 Mon Sep 17 00:00:00 2001 From: Arpad Ryszka Date: Wed, 25 Mar 2026 22:47:56 +0100 Subject: [PATCH] implement buffered writer --- io_test.go | 6 + lib.go | 59 ++++++- lib_test.go | 48 ++++++ writer.go | 88 +++++++++++ writer_test.go | 405 +++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 604 insertions(+), 2 deletions(-) create mode 100644 writer.go create mode 100644 writer_test.go diff --git a/io_test.go b/io_test.go index 4ca6b30..bc86880 100644 --- a/io_test.go +++ b/io_test.go @@ -22,6 +22,7 @@ type gen struct { type writer struct { written []byte errAfter []int + zeroAfter []int shortAfter []int } @@ -137,6 +138,11 @@ func (w *writer) Write(p []byte) (int, error) { return 0, errTest } + if len(p) > 0 && len(w.zeroAfter) > 0 && len(w.written) >= w.zeroAfter[0] { + w.zeroAfter = w.zeroAfter[1:] + return 0, nil + } + if len(p) > 0 && len(w.shortAfter) > 0 && len(w.written) >= w.shortAfter[0] { w.shortAfter = w.shortAfter[1:] p = p[:len(p)/2] diff --git a/lib.go b/lib.go index dadcd41..a33849a 100644 --- a/lib.go +++ b/lib.go @@ -1,7 +1,8 @@ // Package buffer provides pooled Buffer IO for Go programs. // // It implements a reader similar to bufio.Reader. The underlying memory buffers can be used from a synchronized -// pool. +// pool. It implements a writer that can be used to avoid writing too small number of bytes to an underlying +// writer. package buffer import ( @@ -36,7 +37,7 @@ type Options struct { // reader to be executed in goroutines other than what they were created in. type ContentFunc func(io.Writer) (int64, error) -// Reader wraps an underlying io.Reader or io.WriterTo, and provides buffered io, via its methods. Initialize it +// Reader wraps an underlying io.Reader or io.WriterTo, and provides buffered io via its methods. Initialize it // via BufferedReader or BufferedContent. // // It reads from the underlying source until the first error, but only returns an error when the buffer is @@ -47,6 +48,17 @@ type Reader struct { reader *reader } +// Writer wraps an underlying io.Writer, and provides buffered io via its methods. Initialize it via +// BufferedWriter. +// +// It writes the input bytes into an internal buffer, and flushes them to the underlying writer only when the +// buffer is full or when the writer is closed. If necessary, the buffer can be explicitly flushed. +// +// The writer does not support concurrent access. +type Writer struct { + writer *writer +} + var ( // ErrZeroAllocation is returned when the used pool returned a zero length byte slice. ErrZeroAllocation = errors.New("zero allocation") @@ -220,3 +232,46 @@ func (r Reader) Close() { r.reader.free() } + +// BufferedWriter initializes a Writer. +func BufferedWriter(out io.Writer, o Options) Writer { + if out == nil { + return Writer{} + } + + if o.Pool == nil { + o.Pool = DefaultPool(1 << 12) + } + + return Writer{writer: &writer{out: out, options: o}} +} + +// Write writes to the writer's buffer, and if the buffer is full, it causes writing out the buffer's contents +// to the underlying writer. +func (w Writer) Write(p []byte) (int, error) { + if w.writer == nil { + return 0, errors.New("unitialized writer") + } + + return w.writer.write(p) +} + +// Flush forces the writer to flush the buffered content to the underlying writer. After flushed, the writer +// still accepts further writes. +func (w Writer) Flush() error { + if w.writer == nil { + return nil + } + + return w.writer.flush() +} + +// Close flushes the buffered content if any and closes the writer. After closed, the writer does not accept +// further writes. +func (w Writer) Close() error { + if w.writer == nil { + return nil + } + + return w.writer.close() +} diff --git a/lib_test.go b/lib_test.go index 3168ec7..b537a80 100644 --- a/lib_test.go +++ b/lib_test.go @@ -303,6 +303,54 @@ func TestLib(t *testing.T) { } }) }) + + t.Run("writer", func(t *testing.T) { + t.Run("uninitialized writer", func(t *testing.T) { + t.Run("no writer", func(t *testing.T) { + w := buffer.BufferedWriter(nil, buffer.Options{}) + if n, err := w.Write([]byte("123")); n != 0 || err == nil { + t.Fatal(n, err) + } + }) + + t.Run("write", func(t *testing.T) { + var w buffer.Writer + if n, err := w.Write([]byte("123")); n != 0 || err == nil { + t.Fatal(n, err) + } + }) + + t.Run("flush", func(t *testing.T) { + var w buffer.Writer + if err := w.Flush(); err != nil { + t.Fatal(err) + } + }) + + t.Run("close", func(t *testing.T) { + var w buffer.Writer + if err := w.Close(); err != nil { + t.Fatal(err) + } + }) + }) + + t.Run("default pool", func(t *testing.T) { + w := &writer{} + b := buffer.BufferedWriter(w, buffer.Options{}) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123" { + t.Fatal(string(w.written)) + } + }) + }) } // -- bench diff --git a/writer.go b/writer.go new file mode 100644 index 0000000..5ee9a9f --- /dev/null +++ b/writer.go @@ -0,0 +1,88 @@ +package buffer + +import ( + "errors" + "io" +) + +type writer struct { + out io.Writer + options Options + buffer []byte + offset, len int + err error +} + +var errClosed = errors.New("buffer closed") + +func (w *writer) write(p []byte) (int, error) { + var n int + for { + if w.err != nil { + return n, w.err + } + + if len(p) == 0 { + return n, nil + } + + if len(w.buffer) == 0 { + w.buffer, w.err = w.options.Pool.Get() + continue + } + + if w.offset+w.len == len(w.buffer) { + w.flush() + continue + } + + ni := copy(w.buffer[w.offset+w.len:], p) + w.len += ni + p = p[ni:] + n += ni + } +} + +func (w *writer) flush() error { + var zeroWrite bool + for { + if w.err != nil { + return w.err + } + + if w.len == 0 { + w.offset = 0 + return nil + } + + var n int + n, w.err = w.out.Write(w.buffer[w.offset : w.offset+w.len]) + if n == 0 && w.err == nil && zeroWrite { + w.err = io.ErrShortWrite + } + + zeroWrite = n == 0 && w.err == nil + w.offset += n + w.len -= n + if w.err != nil { + w.options.Pool.Put(w.buffer) + w.buffer = nil + } + } +} + +func (w *writer) close() error { + if w.err != nil { + return w.err + } + + w.flush() + if w.err != nil { + return w.err + } + + w.err = errClosed + w.options.Pool.Put(w.buffer) + w.buffer = nil + return nil +} diff --git a/writer_test.go b/writer_test.go new file mode 100644 index 0000000..6ba777c --- /dev/null +++ b/writer_test.go @@ -0,0 +1,405 @@ +package buffer_test + +import ( + "code.squareroundforest.org/arpio/buffer" + "errors" + "io" + "testing" +) + +func TestWriter(t *testing.T) { + t.Run("write out", func(t *testing.T) { + w := &writer{} + o := buffer.Options{Pool: buffer.NoPool(32)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("789")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123456789" { + t.Fatal(string(w.written)) + } + }) + + t.Run("zero bytes when empty", func(t *testing.T) { + w := &writer{} + o := buffer.Options{Pool: buffer.NoPool(32)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write(nil); n != 0 || err != nil { + t.Fatal(n, err) + } + }) + + t.Run("zero bytes when erred", func(t *testing.T) { + w := &writer{errAfter: []int{3}} + o := buffer.Options{Pool: buffer.NoPool(2)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + + if n, err := b.Write(nil); n != 0 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + + if string(w.written) != "1234" { + t.Fatal(string(w.written)) + } + }) + + t.Run("zero bytes when not empty", func(t *testing.T) { + w := &writer{errAfter: []int{3}} + o := buffer.Options{Pool: buffer.NoPool(32)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write(nil); n != 0 || err != nil { + t.Fatal(n, err) + } + + if len(w.written) != 0 { + t.Fatal(string(w.written)) + } + }) + + t.Run("get buffer fails", func(t *testing.T) { + w := &writer{} + p := &fakePool{ + allocSize: 32, + errAfter: []int{0}, + } + + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 0 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + }) + + t.Run("no underlying write until full", func(t *testing.T) { + w := &writer{} + o := buffer.Options{Pool: buffer.NoPool(4)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if len(w.written) != 0 { + t.Fatal(string(w.written)) + } + + if n, err := b.Write([]byte("456")); n != 3 || err != nil { + t.Fatal(n, err) + } + if string(w.written) != "1234" { + t.Fatal(string(w.written)) + } + + if n, err := b.Write([]byte("789")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if string(w.written) != "12345678" { + t.Fatal(string(w.written)) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123456789" { + t.Fatal(string(w.written)) + } + }) + + t.Run("auto flush when write larger than buffer", func(t *testing.T) { + w := &writer{} + o := buffer.Options{Pool: buffer.NoPool(4)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123456")); err != nil { + t.Fatal(n, err) + } + + if string(w.written) != "1234" { + t.Fatal(string(w.written)) + } + }) + + t.Run("auto flush when multiple smaller writes fill the buffer", func(t *testing.T) { + w := &writer{} + o := buffer.Options{Pool: buffer.NoPool(4)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); err != nil { + t.Fatal(n, err) + } + + if len(w.written) != 0 { + t.Fatal(string(w.written)) + } + + if n, err := b.Write([]byte("456")); err != nil { + t.Fatal(n, err) + } + + if string(w.written) != "1234" { + t.Fatal(string(w.written)) + } + }) + + t.Run("flush on close", func(t *testing.T) { + w := &writer{} + o := buffer.Options{Pool: buffer.NoPool(4)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); err != nil { + t.Fatal(n, err) + } + + if len(w.written) != 0 { + t.Fatal(string(w.written)) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123" { + t.Fatal(string(w.written)) + } + }) + + t.Run("manual flush", func(t *testing.T) { + w := &writer{} + o := buffer.Options{Pool: buffer.NoPool(4)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); err != nil { + t.Fatal(n, err) + } + + if len(w.written) != 0 { + t.Fatal(string(w.written)) + } + + if err := b.Flush(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123" { + t.Fatal(string(w.written)) + } + }) + + t.Run("write after flush", func(t *testing.T) { + w := &writer{} + o := buffer.Options{Pool: buffer.NoPool(4)} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); err != nil { + t.Fatal(n, err) + } + + if len(w.written) != 0 { + t.Fatal(string(w.written)) + } + + if err := b.Flush(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123" { + t.Fatal(string(w.written)) + } + + if n, err := b.Write([]byte("456")); err != nil { + t.Fatal(n, err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123456" { + t.Fatal(string(w.written)) + } + }) + + t.Run("zero write recover", func(t *testing.T) { + w := &writer{zeroAfter: []int{2}} + p := &fakePool{allocSize: 2} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); err != nil { + t.Fatal(n, err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123456" { + t.Fatal(string(w.written)) + } + }) + + t.Run("zero write terminal", func(t *testing.T) { + w := &writer{zeroAfter: []int{2, 2}} + p := &fakePool{allocSize: 2} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); n != 1 || !errors.Is(err, io.ErrShortWrite) { + t.Fatal(n, err) + } + }) + + t.Run("partial write", func(t *testing.T) { + w := &writer{shortAfter: []int{2}} + p := &fakePool{allocSize: 2} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); err != nil { + t.Fatal(n, err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123456" { + t.Fatal(string(w.written)) + } + }) + + t.Run("buffer released on write error immediately", func(t *testing.T) { + w := &writer{errAfter: []int{0}} + p := &fakePool{allocSize: 2} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 2 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + + if p.alloc != 1 && p.free != 1 { + t.Fatal(p.alloc, p.free) + } + }) + + t.Run("buffer released on write error", func(t *testing.T) { + w := &writer{errAfter: []int{3}} + p := &fakePool{allocSize: 2} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + + if p.alloc != 1 && p.free != 1 { + t.Fatal(p.alloc, p.free) + } + }) + + t.Run("close on err", func(t *testing.T) { + w := &writer{errAfter: []int{3}} + p := &fakePool{allocSize: 2} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) { + t.Fatal(n, err) + } + + if p.alloc != 1 && p.free != 1 { + t.Fatal(p.alloc, p.free) + } + + if err := b.Close(); !errors.Is(err, errTest) { + t.Fatal(err) + } + }) + + t.Run("close and flush err", func(t *testing.T) { + w := &writer{errAfter: []int{3}} + p := &fakePool{allocSize: 2} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if err := b.Close(); !errors.Is(err, errTest) { + t.Fatal(err) + } + }) + + t.Run("close", func(t *testing.T) { + w := &writer{} + p := &fakePool{allocSize: 2} + o := buffer.Options{Pool: p} + b := buffer.BufferedWriter(w, o) + if n, err := b.Write([]byte("123")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if n, err := b.Write([]byte("456")); n != 3 || err != nil { + t.Fatal(n, err) + } + + if err := b.Close(); err != nil { + t.Fatal(err) + } + + if string(w.written) != "123456" { + t.Fatal(string(w.written)) + } + }) +}