diff --git a/content_test.go b/content_test.go index 627876a..3ccd9ab 100644 --- a/content_test.go +++ b/content_test.go @@ -1,43 +1,18 @@ -package buffer +package buffer_test import ( - "testing" - "io" + "code.squareroundforest.org/arpio/buffer" "errors" + "io" + "testing" ) -type testPool struct { - size int - failAfter []int - count int -} - -var ( - errTest = errors.New("test error") - errTest2 = errors.New("test error 2") -) - -func (p *testPool) Get() ([]byte, error) { - defer func() { - p.count++ - }() - - if len(p.failAfter) > 0 && p.count == p.failAfter[0] { - p.failAfter = p.failAfter[1:] - return nil, errTest - } - - return make([]byte, p.size), nil -} - -func (p *testPool) Put([]byte) {} - func TestContent(t *testing.T) { t.Run("eof", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { var n int64 for i := 0; i < 3; i++ { - ni, err := w.Write([]byte("123456789")[i * 3:i * 3 + 3]) + ni, err := w.Write([]byte("123456789")[i*3 : i*3+3]) n += int64(ni) if err != nil { return n, err @@ -47,9 +22,9 @@ func TestContent(t *testing.T) { return n, nil }) - p := &testPool{size: 2} - o := Options{Pool: p} - r := BufferedContent(c, o) + p := &pool{allocSize: 2} + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b := make([]byte, 3) for i := 0; i < 3; i++ { n, err := r.Read(b) @@ -57,7 +32,7 @@ func TestContent(t *testing.T) { t.Fatal(n, err) } - if string(b) != "123456789"[i * 3:i * 3 + 3] { + if string(b) != "123456789"[i*3:i*3+3] { t.Fatal(string(b)) } } @@ -69,13 +44,13 @@ func TestContent(t *testing.T) { }) t.Run("eof right away", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { return 0, nil }) - p := &testPool{size: 2} - o := Options{Pool: p} - r := BufferedContent(c, o) + p := &pool{allocSize: 2} + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b := make([]byte, 3) n, err := r.Read(b) if n != 0 || !errors.Is(err, io.EOF) { @@ -84,10 +59,10 @@ func TestContent(t *testing.T) { }) t.Run("writer error", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { var n int64 for i := 0; i < 3; i++ { - ni, err := w.Write([]byte("123456789")[i * 3:i * 3 + 3]) + ni, err := w.Write([]byte("123456789")[i*3 : i*3+3]) n += int64(ni) if err != nil { return n, err @@ -97,9 +72,9 @@ func TestContent(t *testing.T) { return n, errTest }) - p := &testPool{size: 2} - o := Options{Pool: p} - r := BufferedContent(c, o) + p := &pool{allocSize: 2} + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b := make([]byte, 3) for i := 0; i < 3; i++ { n, err := r.Read(b) @@ -107,7 +82,7 @@ func TestContent(t *testing.T) { t.Fatal(n, err) } - if string(b) != "123456789"[i * 3:i * 3 + 3] { + if string(b) != "123456789"[i*3:i*3+3] { t.Fatal(string(b)) } } @@ -119,13 +94,13 @@ func TestContent(t *testing.T) { }) t.Run("writer error right away", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { return 0, errTest }) - p := &testPool{size: 2} - o := Options{Pool: p} - r := BufferedContent(c, o) + p := &pool{allocSize: 2} + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b := make([]byte, 3) n, err := r.Read(b) if n != 0 || !errors.Is(err, errTest) { @@ -134,10 +109,10 @@ func TestContent(t *testing.T) { }) t.Run("abort", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { var n int64 for i := 0; i < 3; i++ { - ni, err := w.Write([]byte("123456789")[i * 3:i * 3 + 3]) + ni, err := w.Write([]byte("123456789")[i*3 : i*3+3]) n += int64(ni) if err != nil { return n, err @@ -147,13 +122,13 @@ func TestContent(t *testing.T) { return n, nil }) - p := &testPool{ - size: 2, - failAfter: []int{1}, + p := &pool{ + allocSize: 2, + errAfter: []int{1}, } - o := Options{Pool: p} - r := BufferedContent(c, o) + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b, ok, err := r.ReadBytes([]byte("67"), 12) if string(b) != "12" /* segment size og 2 by the pool */ || ok || err != nil { t.Fatal(string(b), ok, err) @@ -166,10 +141,10 @@ func TestContent(t *testing.T) { }) t.Run("abort right away", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { var n int64 for i := 0; i < 3; i++ { - ni, err := w.Write([]byte("123456789")[i * 3:i * 3 + 3]) + ni, err := w.Write([]byte("123456789")[i*3 : i*3+3]) n += int64(ni) if err != nil { return n, err @@ -179,13 +154,13 @@ func TestContent(t *testing.T) { return n, nil }) - p := &testPool{ - size: 2, - failAfter: []int{0}, + p := &pool{ + allocSize: 2, + errAfter: []int{0}, } - o := Options{Pool: p} - r := BufferedContent(c, o) + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b, ok, err := r.ReadBytes([]byte("67"), 12) if len(b) != 0 || ok || !errors.Is(err, errTest) { t.Fatal(string(b), ok, err) @@ -193,20 +168,20 @@ func TestContent(t *testing.T) { }) t.Run("close when implementation ignores writer errors", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { w.Write([]byte("123")) w.Write([]byte("456")) w.Write([]byte("123")) return 0, nil }) - p := &testPool{ - size: 2, - failAfter: []int{1}, + p := &pool{ + allocSize: 2, + errAfter: []int{1}, } - o := Options{Pool: p} - r := BufferedContent(c, o) + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b, ok, err := r.ReadBytes([]byte("67"), 12) if string(b) != "12" /* segment size og 2 by the pool */ || ok || err != nil { t.Fatal(string(b), ok, err) @@ -219,7 +194,7 @@ func TestContent(t *testing.T) { }) t.Run("zero write", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { w.Write([]byte("123")) w.Write(nil) w.Write([]byte("456")) @@ -227,9 +202,9 @@ func TestContent(t *testing.T) { return 0, nil }) - p := &testPool{size: 2} - o := Options{Pool: p} - r := BufferedContent(c, o) + p := &pool{allocSize: 2} + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b := make([]byte, 3) for i := 0; i < 3; i++ { n, err := r.Read(b) @@ -237,7 +212,7 @@ func TestContent(t *testing.T) { t.Fatal(n, err) } - if string(b) != "123456789"[i * 3:i * 3 + 3] { + if string(b) != "123456789"[i*3:i*3+3] { t.Fatal(string(b)) } } @@ -249,7 +224,7 @@ func TestContent(t *testing.T) { }) t.Run("zero write right away", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { w.Write(nil) w.Write([]byte("123")) w.Write([]byte("456")) @@ -257,9 +232,9 @@ func TestContent(t *testing.T) { return 0, nil }) - p := &testPool{size: 2} - o := Options{Pool: p} - r := BufferedContent(c, o) + p := &pool{allocSize: 2} + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b := make([]byte, 3) for i := 0; i < 3; i++ { n, err := r.Read(b) @@ -267,7 +242,7 @@ func TestContent(t *testing.T) { t.Fatal(n, err) } - if string(b) != "123456789"[i * 3:i * 3 + 3] { + if string(b) != "123456789"[i*3:i*3+3] { t.Fatal(string(b)) } } @@ -279,10 +254,10 @@ func TestContent(t *testing.T) { }) t.Run("custom error", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { var n int64 for i := 0; i < 3; i++ { - ni, err := w.Write([]byte("123456789")[i * 3:i * 3 + 3]) + ni, err := w.Write([]byte("123456789")[i*3 : i*3+3]) n += int64(ni) if err != nil { return n, err @@ -292,9 +267,9 @@ func TestContent(t *testing.T) { return n, errTest }) - p := &testPool{size: 3} - o := Options{Pool: p} - r := BufferedContent(c, o) + p := &pool{allocSize: 3} + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b := make([]byte, 3) for i := 0; i < 3; i++ { n, err := r.Read(b) @@ -302,7 +277,7 @@ func TestContent(t *testing.T) { t.Fatal(n, err) } - if string(b) != "123456789"[i * 3:i * 3 + 3] { + if string(b) != "123456789"[i*3:i*3+3] { t.Fatal(string(b)) } } @@ -314,20 +289,20 @@ func TestContent(t *testing.T) { }) t.Run("custom error with pool error", func(t *testing.T) { - c := ContentFunc(func(w io.Writer) (int64, error) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { w.Write([]byte("123")) w.Write([]byte("456")) w.Write([]byte("123")) return 0, errTest2 }) - p := &testPool{ - size: 2, - failAfter: []int{1}, + p := &pool{ + allocSize: 2, + errAfter: []int{1}, } - o := Options{Pool: p} - r := BufferedContent(c, o) + o := buffer.Options{Pool: p} + r := buffer.BufferedContent(c, o) b, ok, err := r.ReadBytes([]byte("67"), 12) if string(b) != "12" /* segment size og 2 by the pool */ || ok || err != nil { t.Fatal(string(b), ok, err) diff --git a/io_test.go b/io_test.go index d67515d..4ca6b30 100644 --- a/io_test.go +++ b/io_test.go @@ -30,6 +30,7 @@ var ( utf8Range = []byte("aábéícóöődúüeű") utf8W2Range = []byte("áéíóöőúüű") errTest = errors.New("test error") + errTest2 = errors.New("test error 2") ) func (g *gen) Read(p []byte) (int, error) { diff --git a/lib.go b/lib.go index 6069625..801a8a9 100644 --- a/lib.go +++ b/lib.go @@ -30,11 +30,11 @@ type Reader struct { var ( ErrZeroAllocation = errors.New("zero allocation") - ErrContentAbort = errors.New("content pipe aborted") + ErrContentAbort = errors.New("content pipe aborted") ) -func DefaultPool() Pool { - return newPool() +func DefaultPool(allocSize int) Pool { + return newPool(allocSize) } func NoPool(allocSize int) Pool { @@ -52,7 +52,7 @@ func BufferedReader(in io.Reader, o Options) Reader { } if o.Pool == nil { - o.Pool = DefaultPool() + o.Pool = DefaultPool(1 << 12) } return Reader{reader: &reader{options: o, in: in}} @@ -62,7 +62,7 @@ func BufferedReader(in io.Reader, o Options) Reader { // the individual Write calls are blocked until the reading end requests more data // WriterTo instances need to be safe to call in goroutines other than they were created in // if it returns with nil error, it will be interpreted as EOF on the reader side -// unfinished calls to the passed in io.Writer will return with ErrContentAbort when the buffer +// unfinished calls to the passed in io.Writer will return with ErrContentAbort when the buffer // needs to be consumed func BufferedContent(c io.WriterTo, o Options) Reader { if c == nil { @@ -70,7 +70,7 @@ func BufferedContent(c io.WriterTo, o Options) Reader { } if o.Pool == nil { - o.Pool = DefaultPool() + o.Pool = DefaultPool(1 << 12) } return Reader{reader: &reader{options: o, in: mkcontent(c)}} diff --git a/lib_test.go b/lib_test.go new file mode 100644 index 0000000..660c5a4 --- /dev/null +++ b/lib_test.go @@ -0,0 +1,152 @@ +package buffer_test + +import ( + "bytes" + "code.squareroundforest.org/arpio/buffer" + "errors" + "io" + "testing" +) + +func TestLib(t *testing.T) { + t.Run("default pool", func(t *testing.T) { + t.Run("buffered reader", func(t *testing.T) { + g := &gen{max: 1 << 18} + r := buffer.BufferedReader(g, buffer.Options{}) + b, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(b, generate(1<<18)) { + t.Fatal("output does not match", len(b)) + } + }) + + t.Run("buffered content", func(t *testing.T) { + c := buffer.ContentFunc(func(w io.Writer) (int64, error) { + g := &gen{max: 1 << 18} + return io.Copy(w, g) + }) + + r := buffer.BufferedContent(c, buffer.Options{}) + b, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(b, generate(1<<18)) { + t.Fatal("output does not match", len(b)) + } + }) + }) + + t.Run("zero reader", func(t *testing.T) { + t.Run("buffered reader", func(t *testing.T) { + r := buffer.BufferedReader(nil, buffer.Options{}) + b, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + if len(b) != 0 { + t.Fatal("output does not match", len(b)) + } + }) + + t.Run("buffered content", func(t *testing.T) { + r := buffer.BufferedContent(nil, buffer.Options{}) + b, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + if len(b) != 0 { + t.Fatal("output does not match", len(b)) + } + }) + }) + + t.Run("uninitialized reader", func(t *testing.T) { + t.Run("read", func(t *testing.T) { + var r buffer.Reader + p := make([]byte, 512) + n, err := r.Read(p) + if !errors.Is(err, io.EOF) { + t.Fatal(err) + } + + if n != 0 { + t.Fatal(n) + } + }) + + t.Run("read bytes", func(t *testing.T) { + var r buffer.Reader + b, ok, err := r.ReadBytes([]byte("123"), 512) + if !errors.Is(err, io.EOF) { + t.Fatal(err) + } + + if ok { + t.Fatal(ok) + } + + if len(b) != 0 { + t.Fatal(len(b)) + } + }) + + t.Run("read utf8", func(t *testing.T) { + var r buffer.Reader + runes, n, err := r.ReadUTF8(512) + if !errors.Is(err, io.EOF) { + t.Fatal(err) + } + + if n != 0 { + t.Fatal(n) + } + + if len(runes) != 0 { + t.Fatal(len(runes)) + } + }) + + t.Run("peek", func(t *testing.T) { + var r buffer.Reader + b, err := r.Peek(512) + if !errors.Is(err, io.EOF) { + t.Fatal(err) + } + + if len(b) != 0 { + t.Fatal(len(b)) + } + }) + + t.Run("buffered", func(t *testing.T) { + var r buffer.Reader + b := r.Buffered() + if len(b) != 0 { + t.Fatal(len(b)) + } + }) + + t.Run("write to", func(t *testing.T) { + var ( + r buffer.Reader + b bytes.Buffer + ) + + n, err := r.WriteTo(&b) + if err != nil { + t.Fatal(err) + } + + if n != 0 { + t.Fatal(n) + } + }) + }) +} diff --git a/pool.go b/pool.go index 4be2f9e..fdf4a6a 100644 --- a/pool.go +++ b/pool.go @@ -1,13 +1,23 @@ package buffer +import "sync" + type noPool struct { allocSize int } -type pool struct{} +type pool struct { + sp *sync.Pool +} -func newPool() *pool { - return &pool{} +func newPool(allocSize int) *pool { + sp := &sync.Pool{ + New: func() any { + return make([]byte, allocSize) + }, + } + + return &pool{sp: sp} } func (p noPool) Get() ([]byte, error) { @@ -18,8 +28,9 @@ func (noPool) Put([]byte) { } func (p *pool) Get() ([]byte, error) { - return nil, nil + return p.sp.Get().([]byte), nil } func (p *pool) Put(b []byte) { + p.sp.Put(b) } diff --git a/pool_test.go b/pool_test.go index 7dccbbe..5fa5d78 100644 --- a/pool_test.go +++ b/pool_test.go @@ -836,3 +836,16 @@ func TestPoolUsage(t *testing.T) { }) } } + +func TestDefaultPool(t *testing.T) { + g := &gen{max: 1 << 18} + r := buffer.BufferedReader(g, buffer.Options{}) + b, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(b, generate(1<<18)) { + t.Fatal("output does not match", len(b)) + } +}