diff --git a/message_xdr.go b/message_xdr.go index 324125ea4..948e63c32 100644 --- a/message_xdr.go +++ b/message_xdr.go @@ -44,20 +44,28 @@ func (o IndexMessage) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o IndexMessage) MarshalXDR() []byte { +func (o IndexMessage) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o IndexMessage) AppendXDR(bs []byte) []byte { +func (o IndexMessage) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o IndexMessage) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o IndexMessage) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.Folder) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Folder); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("Folder", l, 64) } xw.WriteString(o.Folder) xw.WriteUint32(uint32(len(o.Files))) @@ -142,20 +150,28 @@ func (o FileInfo) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o FileInfo) MarshalXDR() []byte { +func (o FileInfo) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o FileInfo) AppendXDR(bs []byte) []byte { +func (o FileInfo) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o FileInfo) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o FileInfo) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.Name) > 8192 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Name); l > 8192 { + return xw.Tot(), xdr.ElementSizeExceeded("Name", l, 8192) } xw.WriteString(o.Name) xw.WriteUint32(o.Flags) @@ -244,20 +260,28 @@ func (o FileInfoTruncated) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o FileInfoTruncated) MarshalXDR() []byte { +func (o FileInfoTruncated) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o FileInfoTruncated) AppendXDR(bs []byte) []byte { +func (o FileInfoTruncated) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o FileInfoTruncated) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o FileInfoTruncated) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.Name) > 8192 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Name); l > 8192 { + return xw.Tot(), xdr.ElementSizeExceeded("Name", l, 8192) } xw.WriteString(o.Name) xw.WriteUint32(o.Flags) @@ -318,21 +342,29 @@ func (o BlockInfo) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o BlockInfo) MarshalXDR() []byte { +func (o BlockInfo) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o BlockInfo) AppendXDR(bs []byte) []byte { +func (o BlockInfo) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o BlockInfo) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o BlockInfo) encodeXDR(xw *xdr.Writer) (int, error) { xw.WriteUint32(o.Size) - if len(o.Hash) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Hash); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("Hash", l, 64) } xw.WriteBytes(o.Hash) return xw.Tot(), xw.Error() @@ -396,24 +428,32 @@ func (o RequestMessage) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o RequestMessage) MarshalXDR() []byte { +func (o RequestMessage) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o RequestMessage) AppendXDR(bs []byte) []byte { +func (o RequestMessage) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o RequestMessage) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o RequestMessage) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.Folder) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Folder); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("Folder", l, 64) } xw.WriteString(o.Folder) - if len(o.Name) > 8192 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Name); l > 8192 { + return xw.Tot(), xdr.ElementSizeExceeded("Name", l, 8192) } xw.WriteString(o.Name) xw.WriteUint64(o.Offset) @@ -466,15 +506,23 @@ func (o ResponseMessage) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o ResponseMessage) MarshalXDR() []byte { +func (o ResponseMessage) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o ResponseMessage) AppendXDR(bs []byte) []byte { +func (o ResponseMessage) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o ResponseMessage) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o ResponseMessage) encodeXDR(xw *xdr.Writer) (int, error) { @@ -545,28 +593,36 @@ func (o ClusterConfigMessage) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o ClusterConfigMessage) MarshalXDR() []byte { +func (o ClusterConfigMessage) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o ClusterConfigMessage) AppendXDR(bs []byte) []byte { +func (o ClusterConfigMessage) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o ClusterConfigMessage) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o ClusterConfigMessage) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.ClientName) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.ClientName); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("ClientName", l, 64) } xw.WriteString(o.ClientName) - if len(o.ClientVersion) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.ClientVersion); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("ClientVersion", l, 64) } xw.WriteString(o.ClientVersion) - if len(o.Folders) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Folders); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("Folders", l, 64) } xw.WriteUint32(uint32(len(o.Folders))) for i := range o.Folders { @@ -575,8 +631,8 @@ func (o ClusterConfigMessage) encodeXDR(xw *xdr.Writer) (int, error) { return xw.Tot(), err } } - if len(o.Options) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Options); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("Options", l, 64) } xw.WriteUint32(uint32(len(o.Options))) for i := range o.Options { @@ -604,7 +660,7 @@ func (o *ClusterConfigMessage) decodeXDR(xr *xdr.Reader) error { o.ClientVersion = xr.ReadStringMax(64) _FoldersSize := int(xr.ReadUint32()) if _FoldersSize > 64 { - return xdr.ErrElementSizeExceeded + return xdr.ElementSizeExceeded("Folders", _FoldersSize, 64) } o.Folders = make([]Folder, _FoldersSize) for i := range o.Folders { @@ -612,7 +668,7 @@ func (o *ClusterConfigMessage) decodeXDR(xr *xdr.Reader) error { } _OptionsSize := int(xr.ReadUint32()) if _OptionsSize > 64 { - return xdr.ErrElementSizeExceeded + return xdr.ElementSizeExceeded("Options", _OptionsSize, 64) } o.Options = make([]Option, _OptionsSize) for i := range o.Options { @@ -654,20 +710,28 @@ func (o Folder) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o Folder) MarshalXDR() []byte { +func (o Folder) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o Folder) AppendXDR(bs []byte) []byte { +func (o Folder) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Folder) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o Folder) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.ID) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.ID); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("ID", l, 64) } xw.WriteString(o.ID) xw.WriteUint32(uint32(len(o.Devices))) @@ -735,20 +799,28 @@ func (o Device) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o Device) MarshalXDR() []byte { +func (o Device) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o Device) AppendXDR(bs []byte) []byte { +func (o Device) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Device) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o Device) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.ID) > 32 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.ID); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("ID", l, 32) } xw.WriteBytes(o.ID) xw.WriteUint32(o.Flags) @@ -807,24 +879,32 @@ func (o Option) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o Option) MarshalXDR() []byte { +func (o Option) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o Option) AppendXDR(bs []byte) []byte { +func (o Option) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Option) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o Option) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.Key) > 64 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Key); l > 64 { + return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 64) } xw.WriteString(o.Key) - if len(o.Value) > 1024 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Value); l > 1024 { + return xw.Tot(), xdr.ElementSizeExceeded("Value", l, 1024) } xw.WriteString(o.Value) return xw.Tot(), xw.Error() @@ -873,20 +953,28 @@ func (o CloseMessage) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o CloseMessage) MarshalXDR() []byte { +func (o CloseMessage) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o CloseMessage) AppendXDR(bs []byte) []byte { +func (o CloseMessage) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o CloseMessage) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o CloseMessage) encodeXDR(xw *xdr.Writer) (int, error) { - if len(o.Reason) > 1024 { - return xw.Tot(), xdr.ErrElementSizeExceeded + if l := len(o.Reason); l > 1024 { + return xw.Tot(), xdr.ElementSizeExceeded("Reason", l, 1024) } xw.WriteString(o.Reason) return xw.Tot(), xw.Error() @@ -927,15 +1015,23 @@ func (o EmptyMessage) EncodeXDR(w io.Writer) (int, error) { return o.encodeXDR(xw) } -func (o EmptyMessage) MarshalXDR() []byte { +func (o EmptyMessage) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o EmptyMessage) AppendXDR(bs []byte) []byte { +func (o EmptyMessage) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o EmptyMessage) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) - o.encodeXDR(xw) - return []byte(aw) + _, err := o.encodeXDR(xw) + return []byte(aw), err } func (o EmptyMessage) encodeXDR(xw *xdr.Writer) (int, error) { diff --git a/protocol.go b/protocol.go index 19cdfbb16..08ef226f0 100644 --- a/protocol.go +++ b/protocol.go @@ -126,7 +126,7 @@ type hdrMsg struct { } type encodable interface { - AppendXDR([]byte) []byte + AppendXDR([]byte) ([]byte, error) } const ( @@ -483,7 +483,11 @@ func (c *rawConnection) writerLoop() { case hm := <-c.outbox: if hm.msg != nil { // Uncompressed message in uncBuf - uncBuf = hm.msg.AppendXDR(uncBuf[:0]) + uncBuf, err = hm.msg.AppendXDR(uncBuf[:0]) + if err != nil { + c.close(err) + return + } if len(uncBuf) >= c.compressionThreshold { // Use compression for large messages diff --git a/protocol_test.go b/protocol_test.go index cc6f9472a..a7bb1416c 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -25,6 +25,7 @@ import ( "io/ioutil" "os" "reflect" + "strings" "testing" "testing/quick" @@ -369,7 +370,7 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool { } _, err := m1.EncodeXDR(&buf) - if err == xdr.ErrElementSizeExceeded { + if err != nil && strings.Contains(err.Error(), "exceeds size") { return true } if err != nil {