lib/protocol: Don't call receiver after calling Closed (fixes #4170) (#5742)

* lib/protocol: Don't call receiver after calling Closed (fixes #4170)

* review
This commit is contained in:
Simon Frei 2019-05-25 21:08:07 +02:00 committed by Audrius Butkevicius
parent d91da8feee
commit 9e6db72535
3 changed files with 97 additions and 27 deletions

View File

@ -13,6 +13,7 @@ type TestModel struct {
hash []byte hash []byte
weakHash uint32 weakHash uint32
fromTemporary bool fromTemporary bool
indexFn func(DeviceID, string, []FileInfo)
closedCh chan struct{} closedCh chan struct{}
closedErr error closedErr error
} }
@ -24,6 +25,9 @@ func newTestModel() *TestModel {
} }
func (t *TestModel) Index(deviceID DeviceID, folder string, files []FileInfo) { func (t *TestModel) Index(deviceID DeviceID, folder string, files []FileInfo) {
if t.indexFn != nil {
t.indexFn(deviceID, folder, files)
}
} }
func (t *TestModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) { func (t *TestModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) {

View File

@ -182,11 +182,13 @@ type rawConnection struct {
nextID int32 nextID int32
nextIDMut sync.Mutex nextIDMut sync.Mutex
outbox chan asyncMessage inbox chan message
closed chan struct{} outbox chan asyncMessage
closeOnce sync.Once dispatcherLoopStopped chan struct{}
sendCloseOnce sync.Once closed chan struct{}
compression Compression closeOnce sync.Once
sendCloseOnce sync.Once
compression Compression
} }
type asyncResult struct { type asyncResult struct {
@ -220,15 +222,17 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv
cw := &countingWriter{Writer: writer} cw := &countingWriter{Writer: writer}
c := rawConnection{ c := rawConnection{
id: deviceID, id: deviceID,
name: name, name: name,
receiver: nativeModel{receiver}, receiver: nativeModel{receiver},
cr: cr, cr: cr,
cw: cw, cw: cw,
awaiting: make(map[int32]chan asyncResult), awaiting: make(map[int32]chan asyncResult),
outbox: make(chan asyncMessage), inbox: make(chan message),
closed: make(chan struct{}), outbox: make(chan asyncMessage),
compression: compress, dispatcherLoopStopped: make(chan struct{}),
closed: make(chan struct{}),
compression: compress,
} }
return wireFormatConnection{&c} return wireFormatConnection{&c}
@ -237,8 +241,9 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv
// Start creates the goroutines for sending and receiving of messages. It must // Start creates the goroutines for sending and receiving of messages. It must
// be called exactly once after creating a connection. // be called exactly once after creating a connection.
func (c *rawConnection) Start() { func (c *rawConnection) Start() {
go c.readerLoop()
go func() { go func() {
err := c.readerLoop() err := c.dispatcherLoop()
c.internalClose(err) c.internalClose(err)
}() }()
go c.writerLoop() go c.writerLoop()
@ -348,25 +353,37 @@ func (c *rawConnection) ping() bool {
return c.send(&Ping{}, nil) return c.send(&Ping{}, nil)
} }
func (c *rawConnection) readerLoop() (err error) { func (c *rawConnection) readerLoop() {
fourByteBuf := make([]byte, 4) fourByteBuf := make([]byte, 4)
for {
msg, err := c.readMessage(fourByteBuf)
if err != nil {
if err == errUnknownMessage {
// Unknown message types are skipped, for future extensibility.
continue
}
c.internalClose(err)
return
}
select {
case c.inbox <- msg:
case <-c.closed:
return
}
}
}
func (c *rawConnection) dispatcherLoop() (err error) {
defer close(c.dispatcherLoopStopped)
var msg message
state := stateInitial state := stateInitial
for { for {
select { select {
case msg = <-c.inbox:
case <-c.closed: case <-c.closed:
return ErrClosed return ErrClosed
default:
} }
msg, err := c.readMessage(fourByteBuf)
if err == errUnknownMessage {
// Unknown message types are skipped, for future extensibility.
continue
}
if err != nil {
return err
}
switch msg := msg.(type) { switch msg := msg.(type) {
case *ClusterConfig: case *ClusterConfig:
l.Debugln("read ClusterConfig message") l.Debugln("read ClusterConfig message")
@ -847,6 +864,8 @@ func (c *rawConnection) internalClose(err error) {
} }
c.awaitingMut.Unlock() c.awaitingMut.Unlock()
<-c.dispatcherLoopStopped
c.receiver.Closed(c, err) c.receiver.Closed(c, err)
}) })
} }

View File

@ -125,6 +125,53 @@ func TestCloseOnBlockingSend(t *testing.T) {
} }
} }
func TestCloseRace(t *testing.T) {
indexReceived := make(chan struct{})
unblockIndex := make(chan struct{})
m0 := newTestModel()
m0.indexFn = func(_ DeviceID, _ string, _ []FileInfo) {
close(indexReceived)
<-unblockIndex
}
m1 := newTestModel()
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressNever).(wireFormatConnection).Connection.(*rawConnection)
c0.Start()
c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressNever)
c1.Start()
c0.ClusterConfig(ClusterConfig{})
c1.ClusterConfig(ClusterConfig{})
c1.Index("default", nil)
select {
case <-indexReceived:
case <-time.After(time.Second):
t.Fatal("timed out before receiving index")
}
go c0.internalClose(errManual)
select {
case <-c0.closed:
case <-time.After(time.Second):
t.Fatal("timed out before c0.closed was closed")
}
select {
case <-m0.closedCh:
t.Errorf("receiver.Closed called before receiver.Index")
default:
}
close(unblockIndex)
if err := m0.closedError(); err != errManual {
t.Fatal("Connection should be closed")
}
}
func TestMarshalIndexMessage(t *testing.T) { func TestMarshalIndexMessage(t *testing.T) {
if testing.Short() { if testing.Short() {
quickCfg.MaxCount = 10 quickCfg.MaxCount = 10