lib/protocol: Send Close message on read error (#7141)

This commit is contained in:
Simon Frei 2020-11-27 11:31:20 +01:00 committed by GitHub
parent a9764fc16c
commit bbb22c8c80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 23 deletions

View File

@ -170,6 +170,8 @@ type rawConnection struct {
closeOnce sync.Once closeOnce sync.Once
sendCloseOnce sync.Once sendCloseOnce sync.Once
compression Compression compression Compression
loopWG sync.WaitGroup // Need to ensure no leftover routines in testing
} }
type asyncResult struct { type asyncResult struct {
@ -244,20 +246,35 @@ func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, rec
dispatcherLoopStopped: make(chan struct{}), dispatcherLoopStopped: make(chan struct{}),
closed: make(chan struct{}), closed: make(chan struct{}),
compression: compress, compression: compress,
loopWG: sync.WaitGroup{},
} }
} }
// Start creates the goroutines for sending and receiving of messages. It must // Start creates the goroutines for sending and receiving of messages. It must
// be called exactly once after creating a connection. // be called exactly once after creating a connection.
func (c *rawConnection) Start() { func (c *rawConnection) Start() {
go c.readerLoop() c.loopWG.Add(5)
go func() {
c.readerLoop()
c.loopWG.Done()
}()
go func() { go func() {
err := c.dispatcherLoop() err := c.dispatcherLoop()
c.internalClose(err) c.Close(err)
c.loopWG.Done()
}()
go func() {
c.writerLoop()
c.loopWG.Done()
}()
go func() {
c.pingSender()
c.loopWG.Done()
}()
go func() {
c.pingReceiver()
c.loopWG.Done()
}() }()
go c.writerLoop()
go c.pingSender()
go c.pingReceiver()
c.startTime = time.Now() c.startTime = time.Now()
} }
@ -410,7 +427,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
state = stateReady state = stateReady
} }
if err := c.receiver.ClusterConfig(c.id, *msg); err != nil { if err := c.receiver.ClusterConfig(c.id, *msg); err != nil {
return errors.Wrap(err, "receiver error") return fmt.Errorf("receiving cluster config: %w", err)
} }
case *Index: case *Index:
@ -422,7 +439,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
return errors.Wrap(err, "protocol error: index") return errors.Wrap(err, "protocol error: index")
} }
if err := c.handleIndex(*msg); err != nil { if err := c.handleIndex(*msg); err != nil {
return errors.Wrap(err, "receiver error") return fmt.Errorf("receiving index: %w", err)
} }
state = stateReady state = stateReady
@ -435,7 +452,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
return errors.Wrap(err, "protocol error: index update") return errors.Wrap(err, "protocol error: index update")
} }
if err := c.handleIndexUpdate(*msg); err != nil { if err := c.handleIndexUpdate(*msg); err != nil {
return errors.Wrap(err, "receiver error") return fmt.Errorf("receiving index update: %w", err)
} }
state = stateReady state = stateReady
@ -462,7 +479,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
return fmt.Errorf("protocol error: response message in state %d", state) return fmt.Errorf("protocol error: response message in state %d", state)
} }
if err := c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates); err != nil { if err := c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates); err != nil {
return errors.Wrap(err, "receiver error") return fmt.Errorf("receiving download progress: %w", err)
} }
case *Ping: case *Ping:
@ -474,7 +491,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
case *Close: case *Close:
l.Debugln("read Close message") l.Debugln("read Close message")
return errors.New(msg.Reason) return fmt.Errorf("closed by remote: %v", msg.Reason)
default: default:
l.Debugf("read unknown message: %+T", msg) l.Debugf("read unknown message: %+T", msg)

View File

@ -33,8 +33,10 @@ func TestPing(t *testing.T) {
c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0.Start() c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c1.Start() c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{}) c0.ClusterConfig(ClusterConfig{})
c1.ClusterConfig(ClusterConfig{}) c1.ClusterConfig(ClusterConfig{})
@ -57,8 +59,10 @@ func TestClose(t *testing.T) {
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0.Start() c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, m1, "name", CompressionAlways) c1 := NewConnection(c1ID, br, aw, m1, "name", CompressionAlways)
c1.Start() c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{}) c0.ClusterConfig(ClusterConfig{})
c1.ClusterConfig(ClusterConfig{}) c1.ClusterConfig(ClusterConfig{})
@ -97,8 +101,10 @@ func TestCloseOnBlockingSend(t *testing.T) {
m := newTestModel() m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start() c.Start()
defer closeAndWait(c, rw)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
@ -149,8 +155,10 @@ func TestCloseRace(t *testing.T) {
c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressionNever).(wireFormatConnection).Connection.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
c0.Start() c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressionNever) c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressionNever)
c1.Start() c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{}) c0.ClusterConfig(ClusterConfig{})
c1.ClusterConfig(ClusterConfig{}) c1.ClusterConfig(ClusterConfig{})
@ -184,8 +192,10 @@ func TestCloseRace(t *testing.T) {
func TestClusterConfigFirst(t *testing.T) { func TestClusterConfigFirst(t *testing.T) {
m := newTestModel() m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start() c.Start()
defer closeAndWait(c, rw)
select { select {
case c.outbox <- asyncMessage{&Ping{}, nil}: case c.outbox <- asyncMessage{&Ping{}, nil}:
@ -234,8 +244,10 @@ func TestCloseTimeout(t *testing.T) {
m := newTestModel() m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start() c.Start()
defer closeAndWait(c, rw)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -852,8 +864,10 @@ func TestSha256OfEmptyBlock(t *testing.T) {
func TestClusterConfigAfterClose(t *testing.T) { func TestClusterConfigAfterClose(t *testing.T) {
m := newTestModel() m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start() c.Start()
defer closeAndWait(c, rw)
c.internalClose(errManual) c.internalClose(errManual)
@ -874,11 +888,13 @@ func TestDispatcherToCloseDeadlock(t *testing.T) {
// Verify that we don't deadlock when calling Close() from within one of // Verify that we don't deadlock when calling Close() from within one of
// the model callbacks (ClusterConfig). // the model callbacks (ClusterConfig).
m := newTestModel() m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
m.ccFn = func(devID DeviceID, cc ClusterConfig) { m.ccFn = func(devID DeviceID, cc ClusterConfig) {
c.Close(errManual) c.Close(errManual)
} }
c.Start() c.Start()
defer closeAndWait(c, rw)
c.inbox <- &ClusterConfig{} c.inbox <- &ClusterConfig{}
@ -945,3 +961,18 @@ func TestIndexIDString(t *testing.T) {
t.Error(i.String()) t.Error(i.String())
} }
} }
func closeAndWait(c Connection, closers ...io.Closer) {
for _, closer := range closers {
closer.Close()
}
var raw *rawConnection
switch i := c.(type) {
case wireFormatConnection:
raw = i.Connection.(*rawConnection)
case *rawConnection:
raw = i
}
raw.internalClose(ErrClosed)
raw.loopWG.Wait()
}

View File

@ -6,17 +6,40 @@
package testutils package testutils
// BlockingRW implements io.Reader and Writer but never returns when called import (
type BlockingRW struct{ nilChan chan struct{} } "errors"
"sync"
)
func (rw *BlockingRW) Read(p []byte) (n int, err error) { var ErrClosed = errors.New("closed")
<-rw.nilChan
return // BlockingRW implements io.Reader, Writer and Closer, but only returns when closed
type BlockingRW struct {
c chan struct{}
closeOnce sync.Once
} }
func (rw *BlockingRW) Write(p []byte) (n int, err error) { func NewBlockingRW() *BlockingRW {
<-rw.nilChan return &BlockingRW{
return c: make(chan struct{}),
closeOnce: sync.Once{},
}
}
func (rw *BlockingRW) Read(p []byte) (int, error) {
<-rw.c
return 0, ErrClosed
}
func (rw *BlockingRW) Write(p []byte) (int, error) {
<-rw.c
return 0, ErrClosed
}
func (rw *BlockingRW) Close() error {
rw.closeOnce.Do(func() {
close(rw.c)
})
return nil
} }
// NoopRW implements io.Reader and Writer but never returns when called // NoopRW implements io.Reader and Writer but never returns when called