Factor out XDR en/decoding

This commit is contained in:
Jakob Borg 2014-02-15 12:08:55 +01:00
parent 21a7f3960a
commit f89fa6caed
8 changed files with 342 additions and 253 deletions

View File

@ -1,142 +0,0 @@
package protocol
import (
"errors"
"io"
"sync/atomic"
"github.com/calmh/syncthing/buffers"
)
func pad(l int) int {
d := l % 4
if d == 0 {
return 0
}
return 4 - d
}
var padBytes = []byte{0, 0, 0}
type marshalWriter struct {
w io.Writer
tot uint64
err error
b [8]byte
}
// We will never encode nor expect to decode blobs larger than 10 MB. Check
// inserted to protect against attempting to allocate arbitrary amounts of
// memory when reading a corrupt message.
const maxBytesFieldLength = 10 * 1 << 20
var ErrFieldLengthExceeded = errors.New("Protocol error: raw bytes field size exceeds limit")
func (w *marshalWriter) writeString(s string) {
w.writeBytes([]byte(s))
}
func (w *marshalWriter) writeBytes(bs []byte) {
if w.err != nil {
return
}
if len(bs) > maxBytesFieldLength {
w.err = ErrFieldLengthExceeded
return
}
w.writeUint32(uint32(len(bs)))
if w.err != nil {
return
}
_, w.err = w.w.Write(bs)
if p := pad(len(bs)); w.err == nil && p > 0 {
_, w.err = w.w.Write(padBytes[:p])
}
atomic.AddUint64(&w.tot, uint64(len(bs)+pad(len(bs))))
}
func (w *marshalWriter) writeUint32(v uint32) {
if w.err != nil {
return
}
w.b[0] = byte(v >> 24)
w.b[1] = byte(v >> 16)
w.b[2] = byte(v >> 8)
w.b[3] = byte(v)
_, w.err = w.w.Write(w.b[:4])
atomic.AddUint64(&w.tot, 4)
}
func (w *marshalWriter) writeUint64(v uint64) {
if w.err != nil {
return
}
w.b[0] = byte(v >> 56)
w.b[1] = byte(v >> 48)
w.b[2] = byte(v >> 40)
w.b[3] = byte(v >> 32)
w.b[4] = byte(v >> 24)
w.b[5] = byte(v >> 16)
w.b[6] = byte(v >> 8)
w.b[7] = byte(v)
_, w.err = w.w.Write(w.b[:8])
atomic.AddUint64(&w.tot, 8)
}
func (w *marshalWriter) getTot() uint64 {
return atomic.LoadUint64(&w.tot)
}
type marshalReader struct {
r io.Reader
tot uint64
err error
b [8]byte
}
func (r *marshalReader) readString() string {
bs := r.readBytes()
defer buffers.Put(bs)
return string(bs)
}
func (r *marshalReader) readBytes() []byte {
if r.err != nil {
return nil
}
l := int(r.readUint32())
if r.err != nil {
return nil
}
if l > maxBytesFieldLength {
r.err = ErrFieldLengthExceeded
return nil
}
b := buffers.Get(l + pad(l))
_, r.err = io.ReadFull(r.r, b)
atomic.AddUint64(&r.tot, uint64(l+pad(l)))
return b[:l]
}
func (r *marshalReader) readUint32() uint32 {
if r.err != nil {
return 0
}
_, r.err = io.ReadFull(r.r, r.b[:4])
atomic.AddUint64(&r.tot, 8)
return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
}
func (r *marshalReader) readUint64() uint64 {
if r.err != nil {
return 0
}
_, r.err = io.ReadFull(r.r, r.b[:8])
atomic.AddUint64(&r.tot, 8)
return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 |
uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
}
func (r *marshalReader) getTot() uint64 {
return atomic.LoadUint64(&r.tot)
}

View File

@ -3,6 +3,9 @@ package protocol
import ( import (
"errors" "errors"
"io" "io"
"github.com/calmh/syncthing/buffers"
"github.com/calmh/syncthing/xdr"
) )
const ( const (
@ -43,60 +46,93 @@ func decodeHeader(u uint32) header {
} }
} }
func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) {
mw := newMarshalWriter(w)
mw.writeIndex(repo, idx)
return int(mw.Tot()), mw.Err()
}
type marshalWriter struct {
*xdr.Writer
}
func newMarshalWriter(w io.Writer) marshalWriter {
return marshalWriter{xdr.NewWriter(w)}
}
func (w *marshalWriter) writeHeader(h header) { func (w *marshalWriter) writeHeader(h header) {
w.writeUint32(encodeHeader(h)) w.WriteUint32(encodeHeader(h))
} }
func (w *marshalWriter) writeIndex(repo string, idx []FileInfo) { func (w *marshalWriter) writeIndex(repo string, idx []FileInfo) {
w.writeString(repo) w.WriteString(repo)
w.writeUint32(uint32(len(idx))) w.WriteUint32(uint32(len(idx)))
for _, f := range idx { for _, f := range idx {
w.writeString(f.Name) w.WriteString(f.Name)
w.writeUint32(f.Flags) w.WriteUint32(f.Flags)
w.writeUint64(uint64(f.Modified)) w.WriteUint64(uint64(f.Modified))
w.writeUint32(f.Version) w.WriteUint32(f.Version)
w.writeUint32(uint32(len(f.Blocks))) w.WriteUint32(uint32(len(f.Blocks)))
for _, b := range f.Blocks { for _, b := range f.Blocks {
w.writeUint32(b.Size) w.WriteUint32(b.Size)
w.writeBytes(b.Hash) w.WriteBytes(b.Hash)
} }
} }
} }
func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) {
mw := marshalWriter{w: w}
mw.writeIndex(repo, idx)
return int(mw.getTot()), mw.err
}
func (w *marshalWriter) writeRequest(r request) { func (w *marshalWriter) writeRequest(r request) {
w.writeString(r.repo) w.WriteString(r.repo)
w.writeString(r.name) w.WriteString(r.name)
w.writeUint64(uint64(r.offset)) w.WriteUint64(uint64(r.offset))
w.writeUint32(r.size) w.WriteUint32(r.size)
w.writeBytes(r.hash) w.WriteBytes(r.hash)
} }
func (w *marshalWriter) writeResponse(data []byte) { func (w *marshalWriter) writeResponse(data []byte) {
w.writeBytes(data) w.WriteBytes(data)
} }
func (w *marshalWriter) writeOptions(opts map[string]string) { func (w *marshalWriter) writeOptions(opts map[string]string) {
w.writeUint32(uint32(len(opts))) w.WriteUint32(uint32(len(opts)))
for k, v := range opts { for k, v := range opts {
w.writeString(k) w.WriteString(k)
w.writeString(v) w.WriteString(v)
} }
} }
func (r *marshalReader) readHeader() header { func ReadIndex(r io.Reader) (string, []FileInfo, error) {
return decodeHeader(r.readUint32()) mr := newMarshalReader(r)
repo, idx := mr.readIndex()
return repo, idx, mr.Err()
} }
func (r *marshalReader) readIndex() (string, []FileInfo) { type marshalReader struct {
*xdr.Reader
err error
}
func newMarshalReader(r io.Reader) marshalReader {
return marshalReader{
Reader: xdr.NewReader(r),
err: nil,
}
}
func (r marshalReader) Err() error {
if r.err != nil {
return r.err
}
return r.Reader.Err()
}
func (r marshalReader) readHeader() header {
return decodeHeader(r.ReadUint32())
}
func (r marshalReader) readIndex() (string, []FileInfo) {
var files []FileInfo var files []FileInfo
repo := r.readString() repo := r.ReadString()
nfiles := r.readUint32() nfiles := r.ReadUint32()
if nfiles > maxNumFiles { if nfiles > maxNumFiles {
r.err = ErrMaxFilesExceeded r.err = ErrMaxFilesExceeded
return "", nil return "", nil
@ -104,19 +140,19 @@ func (r *marshalReader) readIndex() (string, []FileInfo) {
if nfiles > 0 { if nfiles > 0 {
files = make([]FileInfo, nfiles) files = make([]FileInfo, nfiles)
for i := range files { for i := range files {
files[i].Name = r.readString() files[i].Name = r.ReadString()
files[i].Flags = r.readUint32() files[i].Flags = r.ReadUint32()
files[i].Modified = int64(r.readUint64()) files[i].Modified = int64(r.ReadUint64())
files[i].Version = r.readUint32() files[i].Version = r.ReadUint32()
nblocks := r.readUint32() nblocks := r.ReadUint32()
if nblocks > maxNumBlocks { if nblocks > maxNumBlocks {
r.err = ErrMaxBlocksExceeded r.err = ErrMaxBlocksExceeded
return "", nil return "", nil
} }
blocks := make([]BlockInfo, nblocks) blocks := make([]BlockInfo, nblocks)
for j := range blocks { for j := range blocks {
blocks[j].Size = r.readUint32() blocks[j].Size = r.ReadUint32()
blocks[j].Hash = r.readBytes() blocks[j].Hash = r.ReadBytes(buffers.Get(32))
} }
files[i].Blocks = blocks files[i].Blocks = blocks
} }
@ -124,32 +160,26 @@ func (r *marshalReader) readIndex() (string, []FileInfo) {
return repo, files return repo, files
} }
func ReadIndex(r io.Reader) (string, []FileInfo, error) { func (r marshalReader) readRequest() request {
mr := marshalReader{r: r}
repo, idx := mr.readIndex()
return repo, idx, mr.err
}
func (r *marshalReader) readRequest() request {
var req request var req request
req.repo = r.readString() req.repo = r.ReadString()
req.name = r.readString() req.name = r.ReadString()
req.offset = int64(r.readUint64()) req.offset = int64(r.ReadUint64())
req.size = r.readUint32() req.size = r.ReadUint32()
req.hash = r.readBytes() req.hash = r.ReadBytes(buffers.Get(32))
return req return req
} }
func (r *marshalReader) readResponse() []byte { func (r marshalReader) readResponse() []byte {
return r.readBytes() return r.ReadBytes(buffers.Get(128 * 1024))
} }
func (r *marshalReader) readOptions() map[string]string { func (r marshalReader) readOptions() map[string]string {
n := r.readUint32() n := r.ReadUint32()
opts := make(map[string]string, n) opts := make(map[string]string, n)
for i := 0; i < int(n); i++ { for i := 0; i < int(n); i++ {
k := r.readString() k := r.ReadString()
v := r.readString() v := r.ReadString()
opts[k] = v opts[k] = v
} }
return opts return opts

View File

@ -34,10 +34,10 @@ func TestIndex(t *testing.T) {
} }
var buf = new(bytes.Buffer) var buf = new(bytes.Buffer)
var wr = marshalWriter{w: buf} var wr = newMarshalWriter(buf)
wr.writeIndex("default", idx) wr.writeIndex("default", idx)
var rd = marshalReader{r: buf} var rd = newMarshalReader(buf)
var repo, idx2 = rd.readIndex() var repo, idx2 = rd.readIndex()
if repo != "default" { if repo != "default" {
@ -53,9 +53,9 @@ func TestRequest(t *testing.T) {
f := func(repo, name string, offset int64, size uint32, hash []byte) bool { f := func(repo, name string, offset int64, size uint32, hash []byte) bool {
var buf = new(bytes.Buffer) var buf = new(bytes.Buffer)
var req = request{repo, name, offset, size, hash} var req = request{repo, name, offset, size, hash}
var wr = marshalWriter{w: buf} var wr = newMarshalWriter(buf)
wr.writeRequest(req) wr.writeRequest(req)
var rd = marshalReader{r: buf} var rd = newMarshalReader(buf)
var req2 = rd.readRequest() var req2 = rd.readRequest()
return req.name == req2.name && return req.name == req2.name &&
req.offset == req2.offset && req.offset == req2.offset &&
@ -70,9 +70,9 @@ func TestRequest(t *testing.T) {
func TestResponse(t *testing.T) { func TestResponse(t *testing.T) {
f := func(data []byte) bool { f := func(data []byte) bool {
var buf = new(bytes.Buffer) var buf = new(bytes.Buffer)
var wr = marshalWriter{w: buf} var wr = newMarshalWriter(buf)
wr.writeResponse(data) wr.writeResponse(data)
var rd = marshalReader{r: buf} var rd = newMarshalReader(buf)
var read = rd.readResponse() var read = rd.readResponse()
return bytes.Compare(read, data) == 0 return bytes.Compare(read, data) == 0
} }
@ -106,7 +106,7 @@ func BenchmarkWriteIndex(b *testing.B) {
}, },
} }
var wr = marshalWriter{w: ioutil.Discard} var wr = newMarshalWriter(ioutil.Discard)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
wr.writeIndex("default", idx) wr.writeIndex("default", idx)
@ -115,7 +115,7 @@ func BenchmarkWriteIndex(b *testing.B) {
func BenchmarkWriteRequest(b *testing.B) { func BenchmarkWriteRequest(b *testing.B) {
var req = request{"default", "blah blah", 1231323, 13123123, []byte("hash hash hash")} var req = request{"default", "blah blah", 1231323, 13123123, []byte("hash hash hash")}
var wr = marshalWriter{w: ioutil.Discard} var wr = newMarshalWriter(ioutil.Discard)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
wr.writeRequest(req) wr.writeRequest(req)
@ -131,10 +131,10 @@ func TestOptions(t *testing.T) {
} }
var buf = new(bytes.Buffer) var buf = new(bytes.Buffer)
var wr = marshalWriter{w: buf} var wr = newMarshalWriter(buf)
wr.writeOptions(opts) wr.writeOptions(opts)
var rd = marshalReader{r: buf} var rd = newMarshalReader(buf)
var ropts = rd.readOptions() var ropts = rd.readOptions()
if !reflect.DeepEqual(opts, ropts) { if !reflect.DeepEqual(opts, ropts) {

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/calmh/syncthing/buffers" "github.com/calmh/syncthing/buffers"
"github.com/calmh/syncthing/xdr"
) )
const ( const (
@ -61,9 +62,9 @@ type Connection struct {
id string id string
receiver Model receiver Model
reader io.Reader reader io.Reader
mreader *marshalReader mreader marshalReader
writer io.Writer writer io.Writer
mwriter *marshalWriter mwriter marshalWriter
closed bool closed bool
awaiting map[int]chan asyncResult awaiting map[int]chan asyncResult
nextId int nextId int
@ -101,9 +102,9 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
id: nodeID, id: nodeID,
receiver: receiver, receiver: receiver,
reader: flrd, reader: flrd,
mreader: &marshalReader{r: flrd}, mreader: marshalReader{Reader: xdr.NewReader(flrd)},
writer: flwr, writer: flwr,
mwriter: &marshalWriter{w: flwr}, mwriter: marshalWriter{Writer: xdr.NewWriter(flwr)},
awaiting: make(map[int]chan asyncResult), awaiting: make(map[int]chan asyncResult),
indexSent: make(map[string]map[string][2]int64), indexSent: make(map[string]map[string][2]int64),
} }
@ -168,8 +169,8 @@ func (c *Connection) Index(repo string, idx []FileInfo) {
if err != nil { if err != nil {
c.close(err) c.close(err)
return return
} else if c.mwriter.err != nil { } else if c.mwriter.Err() != nil {
c.close(c.mwriter.err) c.close(c.mwriter.Err())
return return
} }
} }
@ -185,10 +186,10 @@ func (c *Connection) Request(repo string, name string, offset int64, size uint32
c.awaiting[c.nextId] = rc c.awaiting[c.nextId] = rc
c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest}) c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest})
c.mwriter.writeRequest(request{repo, name, offset, size, hash}) c.mwriter.writeRequest(request{repo, name, offset, size, hash})
if c.mwriter.err != nil { if c.mwriter.Err() != nil {
c.Unlock() c.Unlock()
c.close(c.mwriter.err) c.close(c.mwriter.Err())
return nil, c.mwriter.err return nil, c.mwriter.Err()
} }
err := c.flush() err := c.flush()
if err != nil { if err != nil {
@ -220,9 +221,9 @@ func (c *Connection) ping() bool {
c.Unlock() c.Unlock()
c.close(err) c.close(err)
return false return false
} else if c.mwriter.err != nil { } else if c.mwriter.Err() != nil {
c.Unlock() c.Unlock()
c.close(c.mwriter.err) c.close(c.mwriter.Err())
return false return false
} }
c.nextId = (c.nextId + 1) & 0xfff c.nextId = (c.nextId + 1) & 0xfff
@ -269,8 +270,8 @@ func (c *Connection) readerLoop() {
loop: loop:
for { for {
hdr := c.mreader.readHeader() hdr := c.mreader.readHeader()
if c.mreader.err != nil { if c.mreader.Err() != nil {
c.close(c.mreader.err) c.close(c.mreader.Err())
break loop break loop
} }
if hdr.version != 0 { if hdr.version != 0 {
@ -282,8 +283,8 @@ loop:
case messageTypeIndex: case messageTypeIndex:
repo, files := c.mreader.readIndex() repo, files := c.mreader.readIndex()
_ = repo _ = repo
if c.mreader.err != nil { if c.mreader.Err() != nil {
c.close(c.mreader.err) c.close(c.mreader.Err())
break loop break loop
} else { } else {
c.receiver.Index(c.id, files) c.receiver.Index(c.id, files)
@ -295,8 +296,8 @@ loop:
case messageTypeIndexUpdate: case messageTypeIndexUpdate:
repo, files := c.mreader.readIndex() repo, files := c.mreader.readIndex()
_ = repo _ = repo
if c.mreader.err != nil { if c.mreader.Err() != nil {
c.close(c.mreader.err) c.close(c.mreader.Err())
break loop break loop
} else { } else {
c.receiver.IndexUpdate(c.id, files) c.receiver.IndexUpdate(c.id, files)
@ -304,8 +305,8 @@ loop:
case messageTypeRequest: case messageTypeRequest:
req := c.mreader.readRequest() req := c.mreader.readRequest()
if c.mreader.err != nil { if c.mreader.Err() != nil {
c.close(c.mreader.err) c.close(c.mreader.Err())
break loop break loop
} }
go c.processRequest(hdr.msgID, req) go c.processRequest(hdr.msgID, req)
@ -313,8 +314,8 @@ loop:
case messageTypeResponse: case messageTypeResponse:
data := c.mreader.readResponse() data := c.mreader.readResponse()
if c.mreader.err != nil { if c.mreader.Err() != nil {
c.close(c.mreader.err) c.close(c.mreader.Err())
break loop break loop
} else { } else {
c.Lock() c.Lock()
@ -323,21 +324,21 @@ loop:
c.Unlock() c.Unlock()
if ok { if ok {
rc <- asyncResult{data, c.mreader.err} rc <- asyncResult{data, c.mreader.Err()}
close(rc) close(rc)
} }
} }
case messageTypePing: case messageTypePing:
c.Lock() c.Lock()
c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong})) c.mwriter.WriteUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
err := c.flush() err := c.flush()
c.Unlock() c.Unlock()
if err != nil { if err != nil {
c.close(err) c.close(err)
break loop break loop
} else if c.mwriter.err != nil { } else if c.mwriter.Err() != nil {
c.close(c.mwriter.err) c.close(c.mwriter.Err())
break loop break loop
} }
@ -376,9 +377,9 @@ func (c *Connection) processRequest(msgID int, req request) {
data, _ := c.receiver.Request(c.id, req.repo, req.name, req.offset, req.size, req.hash) data, _ := c.receiver.Request(c.id, req.repo, req.name, req.offset, req.size, req.hash)
c.Lock() c.Lock()
c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse})) c.mwriter.WriteUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
c.mwriter.writeResponse(data) c.mwriter.writeResponse(data)
err := c.mwriter.err err := c.mwriter.Err()
if err == nil { if err == nil {
err = c.flush() err = c.flush()
} }
@ -427,8 +428,8 @@ func (c *Connection) Statistics() Statistics {
stats := Statistics{ stats := Statistics{
At: time.Now(), At: time.Now(),
InBytesTotal: int(c.mreader.getTot()), InBytesTotal: int(c.mreader.Tot()),
OutBytesTotal: int(c.mwriter.getTot()), OutBytesTotal: int(c.mwriter.Tot()),
} }
return stats return stats

View File

@ -22,23 +22,6 @@ func TestHeaderFunctions(t *testing.T) {
} }
} }
func TestPad(t *testing.T) {
tests := [][]int{
{0, 0},
{1, 3},
{2, 2},
{3, 1},
{4, 0},
{32, 0},
{33, 3},
}
for _, tc := range tests {
if p := pad(tc[0]); p != tc[1] {
t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1])
}
}
}
func TestPing(t *testing.T) { func TestPing(t *testing.T) {
ar, aw := io.Pipe() ar, aw := io.Pipe()
br, bw := io.Pipe() br, bw := io.Pipe()

65
xdr/reader.go Normal file
View File

@ -0,0 +1,65 @@
package xdr
import "io"
type Reader struct {
r io.Reader
tot uint64
err error
b [8]byte
}
func NewReader(r io.Reader) *Reader {
return &Reader{
r: r,
}
}
func (r *Reader) ReadString() string {
return string(r.ReadBytes(nil))
}
func (r *Reader) ReadBytes(dst []byte) []byte {
if r.err != nil {
return nil
}
l := int(r.ReadUint32())
if r.err != nil {
return nil
}
if l+pad(l) > len(dst) {
dst = make([]byte, l+pad(l))
} else {
dst = dst[:l+pad(l)]
}
_, r.err = io.ReadFull(r.r, dst)
r.tot += uint64(l + pad(l))
return dst[:l]
}
func (r *Reader) ReadUint32() uint32 {
if r.err != nil {
return 0
}
_, r.err = io.ReadFull(r.r, r.b[:4])
r.tot += 8
return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
}
func (r *Reader) ReadUint64() uint64 {
if r.err != nil {
return 0
}
_, r.err = io.ReadFull(r.r, r.b[:8])
r.tot += 8
return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 |
uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
}
func (r *Reader) Tot() uint64 {
return r.tot
}
func (r *Reader) Err() error {
return r.err
}

95
xdr/writer.go Normal file
View File

@ -0,0 +1,95 @@
package xdr
import "io"
func pad(l int) int {
d := l % 4
if d == 0 {
return 0
}
return 4 - d
}
var padBytes = []byte{0, 0, 0}
type Writer struct {
w io.Writer
tot uint64
err error
b [8]byte
}
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
}
}
func (w *Writer) WriteString(s string) (int, error) {
return w.WriteBytes([]byte(s))
}
func (w *Writer) WriteBytes(bs []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
w.WriteUint32(uint32(len(bs)))
if w.err != nil {
return 0, w.err
}
var l, n int
n, w.err = w.w.Write(bs)
l += n
if p := pad(len(bs)); w.err == nil && p > 0 {
n, w.err = w.w.Write(padBytes[:p])
l += n
}
w.tot += uint64(l)
return l, w.err
}
func (w *Writer) WriteUint32(v uint32) (int, error) {
if w.err != nil {
return 0, w.err
}
w.b[0] = byte(v >> 24)
w.b[1] = byte(v >> 16)
w.b[2] = byte(v >> 8)
w.b[3] = byte(v)
var l int
l, w.err = w.w.Write(w.b[:4])
w.tot += uint64(l)
return l, w.err
}
func (w *Writer) WriteUint64(v uint64) (int, error) {
if w.err != nil {
return 0, w.err
}
w.b[0] = byte(v >> 56)
w.b[1] = byte(v >> 48)
w.b[2] = byte(v >> 40)
w.b[3] = byte(v >> 32)
w.b[4] = byte(v >> 24)
w.b[5] = byte(v >> 16)
w.b[6] = byte(v >> 8)
w.b[7] = byte(v)
var l int
l, w.err = w.w.Write(w.b[:8])
w.tot += uint64(l)
return l, w.err
}
func (w *Writer) Tot() uint64 {
return w.tot
}
func (w *Writer) Err() error {
return w.err
}

57
xdr/xdr_test.go Normal file
View File

@ -0,0 +1,57 @@
package xdr
import (
"bytes"
"testing"
"testing/quick"
)
func TestPad(t *testing.T) {
tests := [][]int{
{0, 0},
{1, 3},
{2, 2},
{3, 1},
{4, 0},
{32, 0},
{33, 3},
}
for _, tc := range tests {
if p := pad(tc[0]); p != tc[1] {
t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1])
}
}
}
func TestBytesNil(t *testing.T) {
fn := func(bs []byte) bool {
var b = new(bytes.Buffer)
var w = NewWriter(b)
var r = NewReader(b)
w.WriteBytes(bs)
w.WriteBytes(bs)
r.ReadBytes(nil)
res := r.ReadBytes(nil)
return bytes.Compare(bs, res) == 0
}
if err := quick.Check(fn, nil); err != nil {
t.Error(err)
}
}
func TestBytesGiven(t *testing.T) {
fn := func(bs []byte) bool {
var b = new(bytes.Buffer)
var w = NewWriter(b)
var r = NewReader(b)
w.WriteBytes(bs)
w.WriteBytes(bs)
res := make([]byte, 12)
res = r.ReadBytes(res)
res = r.ReadBytes(res)
return bytes.Compare(bs, res) == 0
}
if err := quick.Check(fn, nil); err != nil {
t.Error(err)
}
}