diff --git a/cmd/genxdr/main.go b/cmd/genxdr/main.go index d6488096a..40e2e184f 100644 --- a/cmd/genxdr/main.go +++ b/cmd/genxdr/main.go @@ -84,7 +84,10 @@ func (o {{.TypeName}}) encodeXDR(xw *xdr.Writer) (int, error) { {{end}} xw.Write{{$fieldInfo.Encoder}}(o.{{$fieldInfo.Name}}) {{else}} - o.{{$fieldInfo.Name}}.encodeXDR(xw) + _, err := o.{{$fieldInfo.Name}}.encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } {{end}} {{else}} {{if ge $fieldInfo.Max 1}} @@ -99,7 +102,10 @@ func (o {{.TypeName}}) encodeXDR(xw *xdr.Writer) (int, error) { {{else if $fieldInfo.IsBasic}} xw.Write{{$fieldInfo.Encoder}}(o.{{$fieldInfo.Name}}[i]) {{else}} - o.{{$fieldInfo.Name}}[i].encodeXDR(xw) + _, err := o.{{$fieldInfo.Name}}[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } {{end}} } {{end}} diff --git a/discover/packets_xdr.go b/discover/packets_xdr.go index 96e4fdc17..1c5b16cea 100644 --- a/discover/packets_xdr.go +++ b/discover/packets_xdr.go @@ -126,13 +126,19 @@ func (o Announce) AppendXDR(bs []byte) []byte { func (o Announce) encodeXDR(xw *xdr.Writer) (int, error) { xw.WriteUint32(o.Magic) - o.This.encodeXDR(xw) + _, err := o.This.encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } if len(o.Extra) > 16 { return xw.Tot(), xdr.ErrElementSizeExceeded } xw.WriteUint32(uint32(len(o.Extra))) for i := range o.Extra { - o.Extra[i].encodeXDR(xw) + _, err := o.Extra[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } } return xw.Tot(), xw.Error() } @@ -216,7 +222,10 @@ func (o Node) encodeXDR(xw *xdr.Writer) (int, error) { } xw.WriteUint32(uint32(len(o.Addresses))) for i := range o.Addresses { - o.Addresses[i].encodeXDR(xw) + _, err := o.Addresses[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } } return xw.Tot(), xw.Error() } diff --git a/files/leveldb_xdr.go b/files/leveldb_xdr.go index 39fef82b3..e939087ed 100644 --- a/files/leveldb_xdr.go +++ b/files/leveldb_xdr.go @@ -120,7 +120,10 @@ func (o versionList) AppendXDR(bs []byte) []byte { func (o versionList) encodeXDR(xw *xdr.Writer) (int, error) { xw.WriteUint32(uint32(len(o.versions))) for i := range o.versions { - o.versions[i].encodeXDR(xw) + _, err := o.versions[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } } return xw.Tot(), xw.Error() } diff --git a/protocol/.gitignore b/protocol/.gitignore new file mode 100644 index 000000000..2211df63d --- /dev/null +++ b/protocol/.gitignore @@ -0,0 +1 @@ +*.txt diff --git a/protocol/message_xdr.go b/protocol/message_xdr.go index 917aab6ff..323e7998b 100644 --- a/protocol/message_xdr.go +++ b/protocol/message_xdr.go @@ -66,7 +66,10 @@ func (o IndexMessage) encodeXDR(xw *xdr.Writer) (int, error) { xw.WriteString(o.Repository) xw.WriteUint32(uint32(len(o.Files))) for i := range o.Files { - o.Files[i].encodeXDR(xw) + _, err := o.Files[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } } return xw.Tot(), xw.Error() } @@ -165,7 +168,10 @@ func (o FileInfo) encodeXDR(xw *xdr.Writer) (int, error) { xw.WriteUint64(o.LocalVersion) xw.WriteUint32(uint32(len(o.Blocks))) for i := range o.Blocks { - o.Blocks[i].encodeXDR(xw) + _, err := o.Blocks[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } } return xw.Tot(), xw.Error() } @@ -476,14 +482,20 @@ func (o ClusterConfigMessage) encodeXDR(xw *xdr.Writer) (int, error) { } xw.WriteUint32(uint32(len(o.Repositories))) for i := range o.Repositories { - o.Repositories[i].encodeXDR(xw) + _, err := o.Repositories[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } } if len(o.Options) > 64 { return xw.Tot(), xdr.ErrElementSizeExceeded } xw.WriteUint32(uint32(len(o.Options))) for i := range o.Options { - o.Options[i].encodeXDR(xw) + _, err := o.Options[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } } return xw.Tot(), xw.Error() } @@ -575,7 +587,10 @@ func (o Repository) encodeXDR(xw *xdr.Writer) (int, error) { } xw.WriteUint32(uint32(len(o.Nodes))) for i := range o.Nodes { - o.Nodes[i].encodeXDR(xw) + _, err := o.Nodes[i].encodeXDR(xw) + if err != nil { + return xw.Tot(), err + } } return xw.Tot(), xw.Error() } diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index d8e9c3fd1..7b9d71e71 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -5,12 +5,19 @@ package protocol import ( + "bytes" + "encoding/hex" "errors" + "fmt" "io" + "io/ioutil" + "os" + "reflect" "testing" "testing/quick" "github.com/calmh/syncthing/xdr" + pretty "github.com/tonnerre/golang-pretty" ) var ( @@ -230,3 +237,124 @@ func TestClose(t *testing.T) { t.Error("Request should return an error") } } + +func TestElementSizeExceededNested(t *testing.T) { + m := ClusterConfigMessage{ + Repositories: []Repository{ + {ID: "longstringlongstringlongstringinglongstringlongstringlonlongstringlongstringlon"}, + }, + } + _, err := m.EncodeXDR(ioutil.Discard) + if err == nil { + t.Errorf("ID length %d > max 64, but no error", len(m.Repositories[0].ID)) + } +} + +func TestMarshalIndexMessage(t *testing.T) { + f := func(m1 IndexMessage) bool { + for _, f := range m1.Files { + for i := range f.Blocks { + f.Blocks[i].Offset = 0 + if len(f.Blocks[i].Hash) == 0 { + f.Blocks[i].Hash = nil + } + } + } + + return testMarshal(t, "index", &m1, &IndexMessage{}) + } + + if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil { + t.Error(err) + } +} + +func TestMarshalRequestMessage(t *testing.T) { + f := func(m1 RequestMessage) bool { + return testMarshal(t, "request", &m1, &RequestMessage{}) + } + + if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil { + t.Error(err) + } +} + +func TestMarshalResponseMessage(t *testing.T) { + f := func(m1 ResponseMessage) bool { + if len(m1.Data) == 0 { + m1.Data = nil + } + return testMarshal(t, "response", &m1, &ResponseMessage{}) + } + + if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil { + t.Error(err) + } +} + +func TestMarshalClusterConfigMessage(t *testing.T) { + f := func(m1 ClusterConfigMessage) bool { + return testMarshal(t, "clusterconfig", &m1, &ClusterConfigMessage{}) + } + + if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil { + t.Error(err) + } +} + +func TestMarshalCloseMessage(t *testing.T) { + f := func(m1 CloseMessage) bool { + return testMarshal(t, "close", &m1, &CloseMessage{}) + } + + if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil { + t.Error(err) + } +} + +type message interface { + EncodeXDR(io.Writer) (int, error) + DecodeXDR(io.Reader) error +} + +func testMarshal(t *testing.T, prefix string, m1, m2 message) bool { + var buf bytes.Buffer + + failed := func(bc []byte) { + f, _ := os.Create(prefix + "-1.txt") + pretty.Fprintf(f, "%# v", m1) + f.Close() + f, _ = os.Create(prefix + "-2.txt") + pretty.Fprintf(f, "%# v", m2) + f.Close() + if len(bc) > 0 { + f, _ := os.Create(prefix + "-data.txt") + fmt.Fprint(f, hex.Dump(bc)) + f.Close() + } + } + + _, err := m1.EncodeXDR(&buf) + if err == xdr.ErrElementSizeExceeded { + return true + } + if err != nil { + failed(nil) + t.Fatal(err) + } + + bc := make([]byte, len(buf.Bytes())) + copy(bc, buf.Bytes()) + + err = m2.DecodeXDR(&buf) + if err != nil { + failed(bc) + t.Fatal(err) + } + + ok := reflect.DeepEqual(m1, m2) + if !ok { + failed(bc) + } + return ok +}