package buffer_test import ( "code.squareroundforest.org/arpio/buffer" "errors" "io" "testing" ) func TestWriter(t *testing.T) { t.Run("write out", func(t *testing.T) { w := &writer{} o := buffer.Options{Pool: buffer.NoPool(32)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("789")); n != 3 || err != nil { t.Fatal(n, err) } if err := b.Close(); err != nil { t.Fatal(err) } if string(w.written) != "123456789" { t.Fatal(string(w.written)) } }) t.Run("zero bytes when empty", func(t *testing.T) { w := &writer{} o := buffer.Options{Pool: buffer.NoPool(32)} b := buffer.BufferedWriter(w, o) if n, err := b.Write(nil); n != 0 || err != nil { t.Fatal(n, err) } }) t.Run("zero bytes when erred", func(t *testing.T) { w := &writer{errAfter: []int{3}} o := buffer.Options{Pool: buffer.NoPool(2)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) { t.Fatal(n, err) } if n, err := b.Write(nil); n != 0 || !errors.Is(err, errTest) { t.Fatal(n, err) } if string(w.written) != "1234" { t.Fatal(string(w.written)) } }) t.Run("zero bytes when not empty", func(t *testing.T) { w := &writer{errAfter: []int{3}} o := buffer.Options{Pool: buffer.NoPool(32)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write(nil); n != 0 || err != nil { t.Fatal(n, err) } if len(w.written) != 0 { t.Fatal(string(w.written)) } }) t.Run("get buffer fails", func(t *testing.T) { w := &writer{} p := &fakePool{ allocSize: 32, errAfter: []int{0}, } o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 0 || !errors.Is(err, errTest) { t.Fatal(n, err) } }) t.Run("no underlying write until full", func(t *testing.T) { w := &writer{} o := buffer.Options{Pool: buffer.NoPool(4)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 3 || err != nil { t.Fatal(n, err) } if len(w.written) != 0 { t.Fatal(string(w.written)) } if n, err := b.Write([]byte("456")); n != 3 || err != nil { t.Fatal(n, err) } if string(w.written) != "1234" { t.Fatal(string(w.written)) } if n, err := b.Write([]byte("789")); n != 3 || err != nil { t.Fatal(n, err) } if string(w.written) != "12345678" { t.Fatal(string(w.written)) } if err := b.Close(); err != nil { t.Fatal(err) } if string(w.written) != "123456789" { t.Fatal(string(w.written)) } }) t.Run("auto flush when write larger than buffer", func(t *testing.T) { w := &writer{} o := buffer.Options{Pool: buffer.NoPool(4)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123456")); err != nil { t.Fatal(n, err) } if string(w.written) != "1234" { t.Fatal(string(w.written)) } }) t.Run("auto flush when multiple smaller writes fill the buffer", func(t *testing.T) { w := &writer{} o := buffer.Options{Pool: buffer.NoPool(4)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); err != nil { t.Fatal(n, err) } if len(w.written) != 0 { t.Fatal(string(w.written)) } if n, err := b.Write([]byte("456")); err != nil { t.Fatal(n, err) } if string(w.written) != "1234" { t.Fatal(string(w.written)) } }) t.Run("flush on close", func(t *testing.T) { w := &writer{} o := buffer.Options{Pool: buffer.NoPool(4)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); err != nil { t.Fatal(n, err) } if len(w.written) != 0 { t.Fatal(string(w.written)) } if err := b.Close(); err != nil { t.Fatal(err) } if string(w.written) != "123" { t.Fatal(string(w.written)) } }) t.Run("manual flush", func(t *testing.T) { w := &writer{} o := buffer.Options{Pool: buffer.NoPool(4)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); err != nil { t.Fatal(n, err) } if len(w.written) != 0 { t.Fatal(string(w.written)) } if err := b.Flush(); err != nil { t.Fatal(err) } if string(w.written) != "123" { t.Fatal(string(w.written)) } }) t.Run("write after flush", func(t *testing.T) { w := &writer{} o := buffer.Options{Pool: buffer.NoPool(4)} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); err != nil { t.Fatal(n, err) } if len(w.written) != 0 { t.Fatal(string(w.written)) } if err := b.Flush(); err != nil { t.Fatal(err) } if string(w.written) != "123" { t.Fatal(string(w.written)) } if n, err := b.Write([]byte("456")); err != nil { t.Fatal(n, err) } if err := b.Close(); err != nil { t.Fatal(err) } if string(w.written) != "123456" { t.Fatal(string(w.written)) } }) t.Run("zero write recover", func(t *testing.T) { w := &writer{zeroAfter: []int{2}} p := &fakePool{allocSize: 2} o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); err != nil { t.Fatal(n, err) } if err := b.Close(); err != nil { t.Fatal(err) } if string(w.written) != "123456" { t.Fatal(string(w.written)) } }) t.Run("zero write terminal", func(t *testing.T) { w := &writer{zeroAfter: []int{2, 2}} p := &fakePool{allocSize: 2} o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); n != 1 || !errors.Is(err, io.ErrShortWrite) { t.Fatal(n, err) } }) t.Run("partial write", func(t *testing.T) { w := &writer{shortAfter: []int{2}} p := &fakePool{allocSize: 2} o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); err != nil { t.Fatal(n, err) } if err := b.Close(); err != nil { t.Fatal(err) } if string(w.written) != "123456" { t.Fatal(string(w.written)) } }) t.Run("buffer released on write error immediately", func(t *testing.T) { w := &writer{errAfter: []int{0}} p := &fakePool{allocSize: 2} o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 2 || !errors.Is(err, errTest) { t.Fatal(n, err) } if p.alloc != 1 && p.free != 1 { t.Fatal(p.alloc, p.free) } }) t.Run("buffer released on write error", func(t *testing.T) { w := &writer{errAfter: []int{3}} p := &fakePool{allocSize: 2} o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) { t.Fatal(n, err) } if p.alloc != 1 && p.free != 1 { t.Fatal(p.alloc, p.free) } }) t.Run("close on err", func(t *testing.T) { w := &writer{errAfter: []int{3}} p := &fakePool{allocSize: 2} o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("789")); n != 0 || !errors.Is(err, errTest) { t.Fatal(n, err) } if p.alloc != 1 && p.free != 1 { t.Fatal(p.alloc, p.free) } if err := b.Close(); !errors.Is(err, errTest) { t.Fatal(err) } }) t.Run("close and flush err", func(t *testing.T) { w := &writer{errAfter: []int{3}} p := &fakePool{allocSize: 2} o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); n != 3 || err != nil { t.Fatal(n, err) } if err := b.Close(); !errors.Is(err, errTest) { t.Fatal(err) } }) t.Run("close", func(t *testing.T) { w := &writer{} p := &fakePool{allocSize: 2} o := buffer.Options{Pool: p} b := buffer.BufferedWriter(w, o) if n, err := b.Write([]byte("123")); n != 3 || err != nil { t.Fatal(n, err) } if n, err := b.Write([]byte("456")); n != 3 || err != nil { t.Fatal(n, err) } if err := b.Close(); err != nil { t.Fatal(err) } if string(w.written) != "123456" { t.Fatal(string(w.written)) } }) }