From f89fa6caedc491bb4b735633e04b2b52a023fc51 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Sat, 15 Feb 2014 12:08:55 +0100 Subject: [PATCH] Factor out XDR en/decoding --- protocol/marshal.go | 142 -------------------------------------- protocol/messages.go | 140 ++++++++++++++++++++++--------------- protocol/messages_test.go | 20 +++--- protocol/protocol.go | 59 ++++++++-------- protocol/protocol_test.go | 17 ----- xdr/reader.go | 65 +++++++++++++++++ xdr/writer.go | 95 +++++++++++++++++++++++++ xdr/xdr_test.go | 57 +++++++++++++++ 8 files changed, 342 insertions(+), 253 deletions(-) delete mode 100644 protocol/marshal.go create mode 100644 xdr/reader.go create mode 100644 xdr/writer.go create mode 100644 xdr/xdr_test.go diff --git a/protocol/marshal.go b/protocol/marshal.go deleted file mode 100644 index c1012a1e0..000000000 --- a/protocol/marshal.go +++ /dev/null @@ -1,142 +0,0 @@ -package protocol - -import ( - "errors" - "io" - "sync/atomic" - - "github.com/calmh/syncthing/buffers" -) - -func pad(l int) int { - d := l % 4 - if d == 0 { - return 0 - } - return 4 - d -} - -var padBytes = []byte{0, 0, 0} - -type marshalWriter struct { - w io.Writer - tot uint64 - err error - b [8]byte -} - -// We will never encode nor expect to decode blobs larger than 10 MB. Check -// inserted to protect against attempting to allocate arbitrary amounts of -// memory when reading a corrupt message. -const maxBytesFieldLength = 10 * 1 << 20 - -var ErrFieldLengthExceeded = errors.New("Protocol error: raw bytes field size exceeds limit") - -func (w *marshalWriter) writeString(s string) { - w.writeBytes([]byte(s)) -} - -func (w *marshalWriter) writeBytes(bs []byte) { - if w.err != nil { - return - } - if len(bs) > maxBytesFieldLength { - w.err = ErrFieldLengthExceeded - return - } - w.writeUint32(uint32(len(bs))) - if w.err != nil { - return - } - _, w.err = w.w.Write(bs) - if p := pad(len(bs)); w.err == nil && p > 0 { - _, w.err = w.w.Write(padBytes[:p]) - } - atomic.AddUint64(&w.tot, uint64(len(bs)+pad(len(bs)))) -} - -func (w *marshalWriter) writeUint32(v uint32) { - if w.err != nil { - return - } - w.b[0] = byte(v >> 24) - w.b[1] = byte(v >> 16) - w.b[2] = byte(v >> 8) - w.b[3] = byte(v) - _, w.err = w.w.Write(w.b[:4]) - atomic.AddUint64(&w.tot, 4) -} - -func (w *marshalWriter) writeUint64(v uint64) { - if w.err != nil { - return - } - w.b[0] = byte(v >> 56) - w.b[1] = byte(v >> 48) - w.b[2] = byte(v >> 40) - w.b[3] = byte(v >> 32) - w.b[4] = byte(v >> 24) - w.b[5] = byte(v >> 16) - w.b[6] = byte(v >> 8) - w.b[7] = byte(v) - _, w.err = w.w.Write(w.b[:8]) - atomic.AddUint64(&w.tot, 8) -} - -func (w *marshalWriter) getTot() uint64 { - return atomic.LoadUint64(&w.tot) -} - -type marshalReader struct { - r io.Reader - tot uint64 - err error - b [8]byte -} - -func (r *marshalReader) readString() string { - bs := r.readBytes() - defer buffers.Put(bs) - return string(bs) -} - -func (r *marshalReader) readBytes() []byte { - if r.err != nil { - return nil - } - l := int(r.readUint32()) - if r.err != nil { - return nil - } - if l > maxBytesFieldLength { - r.err = ErrFieldLengthExceeded - return nil - } - b := buffers.Get(l + pad(l)) - _, r.err = io.ReadFull(r.r, b) - atomic.AddUint64(&r.tot, uint64(l+pad(l))) - return b[:l] -} - -func (r *marshalReader) readUint32() uint32 { - if r.err != nil { - return 0 - } - _, r.err = io.ReadFull(r.r, r.b[:4]) - atomic.AddUint64(&r.tot, 8) - return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24 -} - -func (r *marshalReader) readUint64() uint64 { - if r.err != nil { - return 0 - } - _, r.err = io.ReadFull(r.r, r.b[:8]) - atomic.AddUint64(&r.tot, 8) - return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 | - uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56 -} - -func (r *marshalReader) getTot() uint64 { - return atomic.LoadUint64(&r.tot) -} diff --git a/protocol/messages.go b/protocol/messages.go index cfeb8c63f..39067b358 100644 --- a/protocol/messages.go +++ b/protocol/messages.go @@ -3,6 +3,9 @@ package protocol import ( "errors" "io" + + "github.com/calmh/syncthing/buffers" + "github.com/calmh/syncthing/xdr" ) const ( @@ -43,60 +46,93 @@ func decodeHeader(u uint32) header { } } +func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) { + mw := newMarshalWriter(w) + mw.writeIndex(repo, idx) + return int(mw.Tot()), mw.Err() +} + +type marshalWriter struct { + *xdr.Writer +} + +func newMarshalWriter(w io.Writer) marshalWriter { + return marshalWriter{xdr.NewWriter(w)} +} + func (w *marshalWriter) writeHeader(h header) { - w.writeUint32(encodeHeader(h)) + w.WriteUint32(encodeHeader(h)) } func (w *marshalWriter) writeIndex(repo string, idx []FileInfo) { - w.writeString(repo) - w.writeUint32(uint32(len(idx))) + w.WriteString(repo) + w.WriteUint32(uint32(len(idx))) for _, f := range idx { - w.writeString(f.Name) - w.writeUint32(f.Flags) - w.writeUint64(uint64(f.Modified)) - w.writeUint32(f.Version) - w.writeUint32(uint32(len(f.Blocks))) + w.WriteString(f.Name) + w.WriteUint32(f.Flags) + w.WriteUint64(uint64(f.Modified)) + w.WriteUint32(f.Version) + w.WriteUint32(uint32(len(f.Blocks))) for _, b := range f.Blocks { - w.writeUint32(b.Size) - w.writeBytes(b.Hash) + w.WriteUint32(b.Size) + w.WriteBytes(b.Hash) } } } -func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) { - mw := marshalWriter{w: w} - mw.writeIndex(repo, idx) - return int(mw.getTot()), mw.err -} - func (w *marshalWriter) writeRequest(r request) { - w.writeString(r.repo) - w.writeString(r.name) - w.writeUint64(uint64(r.offset)) - w.writeUint32(r.size) - w.writeBytes(r.hash) + w.WriteString(r.repo) + w.WriteString(r.name) + w.WriteUint64(uint64(r.offset)) + w.WriteUint32(r.size) + w.WriteBytes(r.hash) } func (w *marshalWriter) writeResponse(data []byte) { - w.writeBytes(data) + w.WriteBytes(data) } func (w *marshalWriter) writeOptions(opts map[string]string) { - w.writeUint32(uint32(len(opts))) + w.WriteUint32(uint32(len(opts))) for k, v := range opts { - w.writeString(k) - w.writeString(v) + w.WriteString(k) + w.WriteString(v) } } -func (r *marshalReader) readHeader() header { - return decodeHeader(r.readUint32()) +func ReadIndex(r io.Reader) (string, []FileInfo, error) { + mr := newMarshalReader(r) + repo, idx := mr.readIndex() + return repo, idx, mr.Err() } -func (r *marshalReader) readIndex() (string, []FileInfo) { +type marshalReader struct { + *xdr.Reader + err error +} + +func newMarshalReader(r io.Reader) marshalReader { + return marshalReader{ + Reader: xdr.NewReader(r), + err: nil, + } +} + +func (r marshalReader) Err() error { + if r.err != nil { + return r.err + } + return r.Reader.Err() +} + +func (r marshalReader) readHeader() header { + return decodeHeader(r.ReadUint32()) +} + +func (r marshalReader) readIndex() (string, []FileInfo) { var files []FileInfo - repo := r.readString() - nfiles := r.readUint32() + repo := r.ReadString() + nfiles := r.ReadUint32() if nfiles > maxNumFiles { r.err = ErrMaxFilesExceeded return "", nil @@ -104,19 +140,19 @@ func (r *marshalReader) readIndex() (string, []FileInfo) { if nfiles > 0 { files = make([]FileInfo, nfiles) for i := range files { - files[i].Name = r.readString() - files[i].Flags = r.readUint32() - files[i].Modified = int64(r.readUint64()) - files[i].Version = r.readUint32() - nblocks := r.readUint32() + files[i].Name = r.ReadString() + files[i].Flags = r.ReadUint32() + files[i].Modified = int64(r.ReadUint64()) + files[i].Version = r.ReadUint32() + nblocks := r.ReadUint32() if nblocks > maxNumBlocks { r.err = ErrMaxBlocksExceeded return "", nil } blocks := make([]BlockInfo, nblocks) for j := range blocks { - blocks[j].Size = r.readUint32() - blocks[j].Hash = r.readBytes() + blocks[j].Size = r.ReadUint32() + blocks[j].Hash = r.ReadBytes(buffers.Get(32)) } files[i].Blocks = blocks } @@ -124,32 +160,26 @@ func (r *marshalReader) readIndex() (string, []FileInfo) { return repo, files } -func ReadIndex(r io.Reader) (string, []FileInfo, error) { - mr := marshalReader{r: r} - repo, idx := mr.readIndex() - return repo, idx, mr.err -} - -func (r *marshalReader) readRequest() request { +func (r marshalReader) readRequest() request { var req request - req.repo = r.readString() - req.name = r.readString() - req.offset = int64(r.readUint64()) - req.size = r.readUint32() - req.hash = r.readBytes() + req.repo = r.ReadString() + req.name = r.ReadString() + req.offset = int64(r.ReadUint64()) + req.size = r.ReadUint32() + req.hash = r.ReadBytes(buffers.Get(32)) return req } -func (r *marshalReader) readResponse() []byte { - return r.readBytes() +func (r marshalReader) readResponse() []byte { + return r.ReadBytes(buffers.Get(128 * 1024)) } -func (r *marshalReader) readOptions() map[string]string { - n := r.readUint32() +func (r marshalReader) readOptions() map[string]string { + n := r.ReadUint32() opts := make(map[string]string, n) for i := 0; i < int(n); i++ { - k := r.readString() - v := r.readString() + k := r.ReadString() + v := r.ReadString() opts[k] = v } return opts diff --git a/protocol/messages_test.go b/protocol/messages_test.go index de3c90755..968d1b292 100644 --- a/protocol/messages_test.go +++ b/protocol/messages_test.go @@ -34,10 +34,10 @@ func TestIndex(t *testing.T) { } var buf = new(bytes.Buffer) - var wr = marshalWriter{w: buf} + var wr = newMarshalWriter(buf) wr.writeIndex("default", idx) - var rd = marshalReader{r: buf} + var rd = newMarshalReader(buf) var repo, idx2 = rd.readIndex() if repo != "default" { @@ -53,9 +53,9 @@ func TestRequest(t *testing.T) { f := func(repo, name string, offset int64, size uint32, hash []byte) bool { var buf = new(bytes.Buffer) var req = request{repo, name, offset, size, hash} - var wr = marshalWriter{w: buf} + var wr = newMarshalWriter(buf) wr.writeRequest(req) - var rd = marshalReader{r: buf} + var rd = newMarshalReader(buf) var req2 = rd.readRequest() return req.name == req2.name && req.offset == req2.offset && @@ -70,9 +70,9 @@ func TestRequest(t *testing.T) { func TestResponse(t *testing.T) { f := func(data []byte) bool { var buf = new(bytes.Buffer) - var wr = marshalWriter{w: buf} + var wr = newMarshalWriter(buf) wr.writeResponse(data) - var rd = marshalReader{r: buf} + var rd = newMarshalReader(buf) var read = rd.readResponse() return bytes.Compare(read, data) == 0 } @@ -106,7 +106,7 @@ func BenchmarkWriteIndex(b *testing.B) { }, } - var wr = marshalWriter{w: ioutil.Discard} + var wr = newMarshalWriter(ioutil.Discard) for i := 0; i < b.N; i++ { wr.writeIndex("default", idx) @@ -115,7 +115,7 @@ func BenchmarkWriteIndex(b *testing.B) { func BenchmarkWriteRequest(b *testing.B) { var req = request{"default", "blah blah", 1231323, 13123123, []byte("hash hash hash")} - var wr = marshalWriter{w: ioutil.Discard} + var wr = newMarshalWriter(ioutil.Discard) for i := 0; i < b.N; i++ { wr.writeRequest(req) @@ -131,10 +131,10 @@ func TestOptions(t *testing.T) { } var buf = new(bytes.Buffer) - var wr = marshalWriter{w: buf} + var wr = newMarshalWriter(buf) wr.writeOptions(opts) - var rd = marshalReader{r: buf} + var rd = newMarshalReader(buf) var ropts = rd.readOptions() if !reflect.DeepEqual(opts, ropts) { diff --git a/protocol/protocol.go b/protocol/protocol.go index 62ce8b9c1..206abdcc7 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -10,6 +10,7 @@ import ( "time" "github.com/calmh/syncthing/buffers" + "github.com/calmh/syncthing/xdr" ) const ( @@ -61,9 +62,9 @@ type Connection struct { id string receiver Model reader io.Reader - mreader *marshalReader + mreader marshalReader writer io.Writer - mwriter *marshalWriter + mwriter marshalWriter closed bool awaiting map[int]chan asyncResult nextId int @@ -101,9 +102,9 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M id: nodeID, receiver: receiver, reader: flrd, - mreader: &marshalReader{r: flrd}, + mreader: marshalReader{Reader: xdr.NewReader(flrd)}, writer: flwr, - mwriter: &marshalWriter{w: flwr}, + mwriter: marshalWriter{Writer: xdr.NewWriter(flwr)}, awaiting: make(map[int]chan asyncResult), indexSent: make(map[string]map[string][2]int64), } @@ -168,8 +169,8 @@ func (c *Connection) Index(repo string, idx []FileInfo) { if err != nil { c.close(err) return - } else if c.mwriter.err != nil { - c.close(c.mwriter.err) + } else if c.mwriter.Err() != nil { + c.close(c.mwriter.Err()) return } } @@ -185,10 +186,10 @@ func (c *Connection) Request(repo string, name string, offset int64, size uint32 c.awaiting[c.nextId] = rc c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest}) c.mwriter.writeRequest(request{repo, name, offset, size, hash}) - if c.mwriter.err != nil { + if c.mwriter.Err() != nil { c.Unlock() - c.close(c.mwriter.err) - return nil, c.mwriter.err + c.close(c.mwriter.Err()) + return nil, c.mwriter.Err() } err := c.flush() if err != nil { @@ -220,9 +221,9 @@ func (c *Connection) ping() bool { c.Unlock() c.close(err) return false - } else if c.mwriter.err != nil { + } else if c.mwriter.Err() != nil { c.Unlock() - c.close(c.mwriter.err) + c.close(c.mwriter.Err()) return false } c.nextId = (c.nextId + 1) & 0xfff @@ -269,8 +270,8 @@ func (c *Connection) readerLoop() { loop: for { hdr := c.mreader.readHeader() - if c.mreader.err != nil { - c.close(c.mreader.err) + if c.mreader.Err() != nil { + c.close(c.mreader.Err()) break loop } if hdr.version != 0 { @@ -282,8 +283,8 @@ loop: case messageTypeIndex: repo, files := c.mreader.readIndex() _ = repo - if c.mreader.err != nil { - c.close(c.mreader.err) + if c.mreader.Err() != nil { + c.close(c.mreader.Err()) break loop } else { c.receiver.Index(c.id, files) @@ -295,8 +296,8 @@ loop: case messageTypeIndexUpdate: repo, files := c.mreader.readIndex() _ = repo - if c.mreader.err != nil { - c.close(c.mreader.err) + if c.mreader.Err() != nil { + c.close(c.mreader.Err()) break loop } else { c.receiver.IndexUpdate(c.id, files) @@ -304,8 +305,8 @@ loop: case messageTypeRequest: req := c.mreader.readRequest() - if c.mreader.err != nil { - c.close(c.mreader.err) + if c.mreader.Err() != nil { + c.close(c.mreader.Err()) break loop } go c.processRequest(hdr.msgID, req) @@ -313,8 +314,8 @@ loop: case messageTypeResponse: data := c.mreader.readResponse() - if c.mreader.err != nil { - c.close(c.mreader.err) + if c.mreader.Err() != nil { + c.close(c.mreader.Err()) break loop } else { c.Lock() @@ -323,21 +324,21 @@ loop: c.Unlock() if ok { - rc <- asyncResult{data, c.mreader.err} + rc <- asyncResult{data, c.mreader.Err()} close(rc) } } case messageTypePing: c.Lock() - c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong})) + c.mwriter.WriteUint32(encodeHeader(header{0, hdr.msgID, messageTypePong})) err := c.flush() c.Unlock() if err != nil { c.close(err) break loop - } else if c.mwriter.err != nil { - c.close(c.mwriter.err) + } else if c.mwriter.Err() != nil { + c.close(c.mwriter.Err()) break loop } @@ -376,9 +377,9 @@ func (c *Connection) processRequest(msgID int, req request) { data, _ := c.receiver.Request(c.id, req.repo, req.name, req.offset, req.size, req.hash) c.Lock() - c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse})) + c.mwriter.WriteUint32(encodeHeader(header{0, msgID, messageTypeResponse})) c.mwriter.writeResponse(data) - err := c.mwriter.err + err := c.mwriter.Err() if err == nil { err = c.flush() } @@ -427,8 +428,8 @@ func (c *Connection) Statistics() Statistics { stats := Statistics{ At: time.Now(), - InBytesTotal: int(c.mreader.getTot()), - OutBytesTotal: int(c.mwriter.getTot()), + InBytesTotal: int(c.mreader.Tot()), + OutBytesTotal: int(c.mwriter.Tot()), } return stats diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index 08e41c49e..0f49c01dc 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -22,23 +22,6 @@ func TestHeaderFunctions(t *testing.T) { } } -func TestPad(t *testing.T) { - tests := [][]int{ - {0, 0}, - {1, 3}, - {2, 2}, - {3, 1}, - {4, 0}, - {32, 0}, - {33, 3}, - } - for _, tc := range tests { - if p := pad(tc[0]); p != tc[1] { - t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1]) - } - } -} - func TestPing(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() diff --git a/xdr/reader.go b/xdr/reader.go new file mode 100644 index 000000000..b5c39a6b2 --- /dev/null +++ b/xdr/reader.go @@ -0,0 +1,65 @@ +package xdr + +import "io" + +type Reader struct { + r io.Reader + tot uint64 + err error + b [8]byte +} + +func NewReader(r io.Reader) *Reader { + return &Reader{ + r: r, + } +} + +func (r *Reader) ReadString() string { + return string(r.ReadBytes(nil)) +} + +func (r *Reader) ReadBytes(dst []byte) []byte { + if r.err != nil { + return nil + } + l := int(r.ReadUint32()) + if r.err != nil { + return nil + } + if l+pad(l) > len(dst) { + dst = make([]byte, l+pad(l)) + } else { + dst = dst[:l+pad(l)] + } + _, r.err = io.ReadFull(r.r, dst) + r.tot += uint64(l + pad(l)) + return dst[:l] +} + +func (r *Reader) ReadUint32() uint32 { + if r.err != nil { + return 0 + } + _, r.err = io.ReadFull(r.r, r.b[:4]) + r.tot += 8 + return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24 +} + +func (r *Reader) ReadUint64() uint64 { + if r.err != nil { + return 0 + } + _, r.err = io.ReadFull(r.r, r.b[:8]) + r.tot += 8 + return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 | + uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56 +} + +func (r *Reader) Tot() uint64 { + return r.tot +} + +func (r *Reader) Err() error { + return r.err +} diff --git a/xdr/writer.go b/xdr/writer.go new file mode 100644 index 000000000..30c7c5639 --- /dev/null +++ b/xdr/writer.go @@ -0,0 +1,95 @@ +package xdr + +import "io" + +func pad(l int) int { + d := l % 4 + if d == 0 { + return 0 + } + return 4 - d +} + +var padBytes = []byte{0, 0, 0} + +type Writer struct { + w io.Writer + tot uint64 + err error + b [8]byte +} + +func NewWriter(w io.Writer) *Writer { + return &Writer{ + w: w, + } +} + +func (w *Writer) WriteString(s string) (int, error) { + return w.WriteBytes([]byte(s)) +} + +func (w *Writer) WriteBytes(bs []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + w.WriteUint32(uint32(len(bs))) + if w.err != nil { + return 0, w.err + } + + var l, n int + n, w.err = w.w.Write(bs) + l += n + + if p := pad(len(bs)); w.err == nil && p > 0 { + n, w.err = w.w.Write(padBytes[:p]) + l += n + } + + w.tot += uint64(l) + return l, w.err +} + +func (w *Writer) WriteUint32(v uint32) (int, error) { + if w.err != nil { + return 0, w.err + } + w.b[0] = byte(v >> 24) + w.b[1] = byte(v >> 16) + w.b[2] = byte(v >> 8) + w.b[3] = byte(v) + + var l int + l, w.err = w.w.Write(w.b[:4]) + w.tot += uint64(l) + return l, w.err +} + +func (w *Writer) WriteUint64(v uint64) (int, error) { + if w.err != nil { + return 0, w.err + } + w.b[0] = byte(v >> 56) + w.b[1] = byte(v >> 48) + w.b[2] = byte(v >> 40) + w.b[3] = byte(v >> 32) + w.b[4] = byte(v >> 24) + w.b[5] = byte(v >> 16) + w.b[6] = byte(v >> 8) + w.b[7] = byte(v) + + var l int + l, w.err = w.w.Write(w.b[:8]) + w.tot += uint64(l) + return l, w.err +} + +func (w *Writer) Tot() uint64 { + return w.tot +} + +func (w *Writer) Err() error { + return w.err +} diff --git a/xdr/xdr_test.go b/xdr/xdr_test.go new file mode 100644 index 000000000..859958ef8 --- /dev/null +++ b/xdr/xdr_test.go @@ -0,0 +1,57 @@ +package xdr + +import ( + "bytes" + "testing" + "testing/quick" +) + +func TestPad(t *testing.T) { + tests := [][]int{ + {0, 0}, + {1, 3}, + {2, 2}, + {3, 1}, + {4, 0}, + {32, 0}, + {33, 3}, + } + for _, tc := range tests { + if p := pad(tc[0]); p != tc[1] { + t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1]) + } + } +} + +func TestBytesNil(t *testing.T) { + fn := func(bs []byte) bool { + var b = new(bytes.Buffer) + var w = NewWriter(b) + var r = NewReader(b) + w.WriteBytes(bs) + w.WriteBytes(bs) + r.ReadBytes(nil) + res := r.ReadBytes(nil) + return bytes.Compare(bs, res) == 0 + } + if err := quick.Check(fn, nil); err != nil { + t.Error(err) + } +} + +func TestBytesGiven(t *testing.T) { + fn := func(bs []byte) bool { + var b = new(bytes.Buffer) + var w = NewWriter(b) + var r = NewReader(b) + w.WriteBytes(bs) + w.WriteBytes(bs) + res := make([]byte, 12) + res = r.ReadBytes(res) + res = r.ReadBytes(res) + return bytes.Compare(bs, res) == 0 + } + if err := quick.Check(fn, nil); err != nil { + t.Error(err) + } +}