1
0

implement buffered writer

This commit is contained in:
Arpad Ryszka 2026-03-25 22:47:56 +01:00
parent 8dd8a636af
commit 9d7bed320b
5 changed files with 604 additions and 2 deletions

View File

@ -22,6 +22,7 @@ type gen struct {
type writer struct { type writer struct {
written []byte written []byte
errAfter []int errAfter []int
zeroAfter []int
shortAfter []int shortAfter []int
} }
@ -137,6 +138,11 @@ func (w *writer) Write(p []byte) (int, error) {
return 0, errTest return 0, errTest
} }
if len(p) > 0 && len(w.zeroAfter) > 0 && len(w.written) >= w.zeroAfter[0] {
w.zeroAfter = w.zeroAfter[1:]
return 0, nil
}
if len(p) > 0 && len(w.shortAfter) > 0 && len(w.written) >= w.shortAfter[0] { if len(p) > 0 && len(w.shortAfter) > 0 && len(w.written) >= w.shortAfter[0] {
w.shortAfter = w.shortAfter[1:] w.shortAfter = w.shortAfter[1:]
p = p[:len(p)/2] p = p[:len(p)/2]

59
lib.go
View File

@ -1,7 +1,8 @@
// Package buffer provides pooled Buffer IO for Go programs. // Package buffer provides pooled Buffer IO for Go programs.
// //
// It implements a reader similar to bufio.Reader. The underlying memory buffers can be used from a synchronized // It implements a reader similar to bufio.Reader. The underlying memory buffers can be used from a synchronized
// pool. // pool. It implements a writer that can be used to avoid writing too small number of bytes to an underlying
// writer.
package buffer package buffer
import ( import (
@ -36,7 +37,7 @@ type Options struct {
// reader to be executed in goroutines other than what they were created in. // reader to be executed in goroutines other than what they were created in.
type ContentFunc func(io.Writer) (int64, error) type ContentFunc func(io.Writer) (int64, error)
// Reader wraps an underlying io.Reader or io.WriterTo, and provides buffered io, via its methods. Initialize it // Reader wraps an underlying io.Reader or io.WriterTo, and provides buffered io via its methods. Initialize it
// via BufferedReader or BufferedContent. // via BufferedReader or BufferedContent.
// //
// It reads from the underlying source until the first error, but only returns an error when the buffer is // It reads from the underlying source until the first error, but only returns an error when the buffer is
@ -47,6 +48,17 @@ type Reader struct {
reader *reader reader *reader
} }
// Writer wraps an underlying io.Writer, and provides buffered io via its methods. Initialize it via
// BufferedWriter.
//
// It writes the input bytes into an internal buffer, and flushes them to the underlying writer only when the
// buffer is full or when the writer is closed. If necessary, the buffer can be explicitly flushed.
//
// The writer does not support concurrent access.
type Writer struct {
writer *writer
}
var ( var (
// ErrZeroAllocation is returned when the used pool returned a zero length byte slice. // ErrZeroAllocation is returned when the used pool returned a zero length byte slice.
ErrZeroAllocation = errors.New("zero allocation") ErrZeroAllocation = errors.New("zero allocation")
@ -220,3 +232,46 @@ func (r Reader) Close() {
r.reader.free() r.reader.free()
} }
// BufferedWriter initializes a Writer.
func BufferedWriter(out io.Writer, o Options) Writer {
if out == nil {
return Writer{}
}
if o.Pool == nil {
o.Pool = DefaultPool(1 << 12)
}
return Writer{writer: &writer{out: out, options: o}}
}
// Write writes to the writer's buffer, and if the buffer is full, it causes writing out the buffer's contents
// to the underlying writer.
func (w Writer) Write(p []byte) (int, error) {
if w.writer == nil {
return 0, errors.New("unitialized writer")
}
return w.writer.write(p)
}
// Flush forces the writer to flush the buffered content to the underlying writer. After flushed, the writer
// still accepts further writes.
func (w Writer) Flush() error {
if w.writer == nil {
return nil
}
return w.writer.flush()
}
// Close flushes the buffered content if any and closes the writer. After closed, the writer does not accept
// further writes.
func (w Writer) Close() error {
if w.writer == nil {
return nil
}
return w.writer.close()
}

View File

@ -303,6 +303,54 @@ func TestLib(t *testing.T) {
} }
}) })
}) })
t.Run("writer", func(t *testing.T) {
t.Run("uninitialized writer", func(t *testing.T) {
t.Run("no writer", func(t *testing.T) {
w := buffer.BufferedWriter(nil, buffer.Options{})
if n, err := w.Write([]byte("123")); n != 0 || err == nil {
t.Fatal(n, err)
}
})
t.Run("write", func(t *testing.T) {
var w buffer.Writer
if n, err := w.Write([]byte("123")); n != 0 || err == nil {
t.Fatal(n, err)
}
})
t.Run("flush", func(t *testing.T) {
var w buffer.Writer
if err := w.Flush(); err != nil {
t.Fatal(err)
}
})
t.Run("close", func(t *testing.T) {
var w buffer.Writer
if err := w.Close(); err != nil {
t.Fatal(err)
}
})
})
t.Run("default pool", func(t *testing.T) {
w := &writer{}
b := buffer.BufferedWriter(w, buffer.Options{})
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if err := b.Close(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123" {
t.Fatal(string(w.written))
}
})
})
} }
// -- bench // -- bench

88
writer.go Normal file
View File

@ -0,0 +1,88 @@
package buffer
import (
"errors"
"io"
)
type writer struct {
out io.Writer
options Options
buffer []byte
offset, len int
err error
}
var errClosed = errors.New("buffer closed")
func (w *writer) write(p []byte) (int, error) {
var n int
for {
if w.err != nil {
return n, w.err
}
if len(p) == 0 {
return n, nil
}
if len(w.buffer) == 0 {
w.buffer, w.err = w.options.Pool.Get()
continue
}
if w.offset+w.len == len(w.buffer) {
w.flush()
continue
}
ni := copy(w.buffer[w.offset+w.len:], p)
w.len += ni
p = p[ni:]
n += ni
}
}
func (w *writer) flush() error {
var zeroWrite bool
for {
if w.err != nil {
return w.err
}
if w.len == 0 {
w.offset = 0
return nil
}
var n int
n, w.err = w.out.Write(w.buffer[w.offset : w.offset+w.len])
if n == 0 && w.err == nil && zeroWrite {
w.err = io.ErrShortWrite
}
zeroWrite = n == 0 && w.err == nil
w.offset += n
w.len -= n
if w.err != nil {
w.options.Pool.Put(w.buffer)
w.buffer = nil
}
}
}
func (w *writer) close() error {
if w.err != nil {
return w.err
}
w.flush()
if w.err != nil {
return w.err
}
w.err = errClosed
w.options.Pool.Put(w.buffer)
w.buffer = nil
return nil
}

405
writer_test.go Normal file
View File

@ -0,0 +1,405 @@
package buffer_test
import (
"code.squareroundforest.org/arpio/buffer"
"errors"
"io"
"testing"
)
func TestWriter(t *testing.T) {
t.Run("write out", func(t *testing.T) {
w := &writer{}
o := buffer.Options{Pool: buffer.NoPool(32)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("789")); n != 3 || err != nil {
t.Fatal(n, err)
}
if err := b.Close(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123456789" {
t.Fatal(string(w.written))
}
})
t.Run("zero bytes when empty", func(t *testing.T) {
w := &writer{}
o := buffer.Options{Pool: buffer.NoPool(32)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write(nil); n != 0 || err != nil {
t.Fatal(n, err)
}
})
t.Run("zero bytes when erred", func(t *testing.T) {
w := &writer{errAfter: []int{3}}
o := buffer.Options{Pool: buffer.NoPool(2)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) {
t.Fatal(n, err)
}
if n, err := b.Write(nil); n != 0 || !errors.Is(err, errTest) {
t.Fatal(n, err)
}
if string(w.written) != "1234" {
t.Fatal(string(w.written))
}
})
t.Run("zero bytes when not empty", func(t *testing.T) {
w := &writer{errAfter: []int{3}}
o := buffer.Options{Pool: buffer.NoPool(32)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write(nil); n != 0 || err != nil {
t.Fatal(n, err)
}
if len(w.written) != 0 {
t.Fatal(string(w.written))
}
})
t.Run("get buffer fails", func(t *testing.T) {
w := &writer{}
p := &fakePool{
allocSize: 32,
errAfter: []int{0},
}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 0 || !errors.Is(err, errTest) {
t.Fatal(n, err)
}
})
t.Run("no underlying write until full", func(t *testing.T) {
w := &writer{}
o := buffer.Options{Pool: buffer.NoPool(4)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if len(w.written) != 0 {
t.Fatal(string(w.written))
}
if n, err := b.Write([]byte("456")); n != 3 || err != nil {
t.Fatal(n, err)
}
if string(w.written) != "1234" {
t.Fatal(string(w.written))
}
if n, err := b.Write([]byte("789")); n != 3 || err != nil {
t.Fatal(n, err)
}
if string(w.written) != "12345678" {
t.Fatal(string(w.written))
}
if err := b.Close(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123456789" {
t.Fatal(string(w.written))
}
})
t.Run("auto flush when write larger than buffer", func(t *testing.T) {
w := &writer{}
o := buffer.Options{Pool: buffer.NoPool(4)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123456")); err != nil {
t.Fatal(n, err)
}
if string(w.written) != "1234" {
t.Fatal(string(w.written))
}
})
t.Run("auto flush when multiple smaller writes fill the buffer", func(t *testing.T) {
w := &writer{}
o := buffer.Options{Pool: buffer.NoPool(4)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); err != nil {
t.Fatal(n, err)
}
if len(w.written) != 0 {
t.Fatal(string(w.written))
}
if n, err := b.Write([]byte("456")); err != nil {
t.Fatal(n, err)
}
if string(w.written) != "1234" {
t.Fatal(string(w.written))
}
})
t.Run("flush on close", func(t *testing.T) {
w := &writer{}
o := buffer.Options{Pool: buffer.NoPool(4)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); err != nil {
t.Fatal(n, err)
}
if len(w.written) != 0 {
t.Fatal(string(w.written))
}
if err := b.Close(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123" {
t.Fatal(string(w.written))
}
})
t.Run("manual flush", func(t *testing.T) {
w := &writer{}
o := buffer.Options{Pool: buffer.NoPool(4)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); err != nil {
t.Fatal(n, err)
}
if len(w.written) != 0 {
t.Fatal(string(w.written))
}
if err := b.Flush(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123" {
t.Fatal(string(w.written))
}
})
t.Run("write after flush", func(t *testing.T) {
w := &writer{}
o := buffer.Options{Pool: buffer.NoPool(4)}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); err != nil {
t.Fatal(n, err)
}
if len(w.written) != 0 {
t.Fatal(string(w.written))
}
if err := b.Flush(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123" {
t.Fatal(string(w.written))
}
if n, err := b.Write([]byte("456")); err != nil {
t.Fatal(n, err)
}
if err := b.Close(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123456" {
t.Fatal(string(w.written))
}
})
t.Run("zero write recover", func(t *testing.T) {
w := &writer{zeroAfter: []int{2}}
p := &fakePool{allocSize: 2}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); err != nil {
t.Fatal(n, err)
}
if err := b.Close(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123456" {
t.Fatal(string(w.written))
}
})
t.Run("zero write terminal", func(t *testing.T) {
w := &writer{zeroAfter: []int{2, 2}}
p := &fakePool{allocSize: 2}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); n != 1 || !errors.Is(err, io.ErrShortWrite) {
t.Fatal(n, err)
}
})
t.Run("partial write", func(t *testing.T) {
w := &writer{shortAfter: []int{2}}
p := &fakePool{allocSize: 2}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); err != nil {
t.Fatal(n, err)
}
if err := b.Close(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123456" {
t.Fatal(string(w.written))
}
})
t.Run("buffer released on write error immediately", func(t *testing.T) {
w := &writer{errAfter: []int{0}}
p := &fakePool{allocSize: 2}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 2 || !errors.Is(err, errTest) {
t.Fatal(n, err)
}
if p.alloc != 1 && p.free != 1 {
t.Fatal(p.alloc, p.free)
}
})
t.Run("buffer released on write error", func(t *testing.T) {
w := &writer{errAfter: []int{3}}
p := &fakePool{allocSize: 2}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) {
t.Fatal(n, err)
}
if p.alloc != 1 && p.free != 1 {
t.Fatal(p.alloc, p.free)
}
})
t.Run("close on err", func(t *testing.T) {
w := &writer{errAfter: []int{3}}
p := &fakePool{allocSize: 2}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) {
t.Fatal(n, err)
}
if p.alloc != 1 && p.free != 1 {
t.Fatal(p.alloc, p.free)
}
if err := b.Close(); !errors.Is(err, errTest) {
t.Fatal(err)
}
})
t.Run("close and flush err", func(t *testing.T) {
w := &writer{errAfter: []int{3}}
p := &fakePool{allocSize: 2}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); n != 3 || err != nil {
t.Fatal(n, err)
}
if err := b.Close(); !errors.Is(err, errTest) {
t.Fatal(err)
}
})
t.Run("close", func(t *testing.T) {
w := &writer{}
p := &fakePool{allocSize: 2}
o := buffer.Options{Pool: p}
b := buffer.BufferedWriter(w, o)
if n, err := b.Write([]byte("123")); n != 3 || err != nil {
t.Fatal(n, err)
}
if n, err := b.Write([]byte("456")); n != 3 || err != nil {
t.Fatal(n, err)
}
if err := b.Close(); err != nil {
t.Fatal(err)
}
if string(w.written) != "123456" {
t.Fatal(string(w.written))
}
})
}