lib/connections, lib/model: Refactor connection close handling (fixes #3466)

So there were some issues here. The main problem was that
model.Close(deviceID) was overloaded to mean "the connection was closed
by the protocol layer" and "I want to close this connection". That meant
it could get called twice - once *to* close the connection and then once
more when the connection *was* closed.

After this refactor there is instead a Closed(conn) method that is the
callback. I didn't need to change the parameter in the end, but I think
it's clearer what it means when it takes the connection that was closed
instead of a device ID. To close a connection, the new close(deviceID)
method is used instead, which only closes the underlying connection and
leaves the cleanup to the Closed() callback.

I also changed how we do connection switching. Instead of the connection
service calling close and then adding the connection, it just adds the
new connection. The model knows that it already has a connection and
makes sure to close and clean out that one before adding the new
connection.

To make sure to sequence this properly I added a new map of channels
that get created on connection add and closed by Closed(), so that
AddConnection() can do the close and wait for the cleanup to happen
before proceeding.

GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/3490
This commit is contained in:
Jakob Borg 2016-08-10 09:37:32 +00:00 committed by Audrius Butkevicius
parent c9cf01e0b6
commit e52be3d83e
7 changed files with 63 additions and 26 deletions

View File

@ -202,7 +202,6 @@ next:
// Lower priority is better, just like nice etc. // Lower priority is better, just like nice etc.
if priorityKnown && ct.Priority > c.Priority { if priorityKnown && ct.Priority > c.Priority {
l.Debugln("Switching connections", remoteID) l.Debugln("Switching connections", remoteID)
s.model.Close(remoteID, protocol.ErrSwitchingConnections)
} else if connected { } else if connected {
// We should not already be connected to the other party. TODO: This // We should not already be connected to the other party. TODO: This
// could use some better handling. If the old connection is dead but // could use some better handling. If the old connection is dead but

View File

@ -8,6 +8,7 @@ package connections
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"net/url" "net/url"
"time" "time"
@ -28,6 +29,10 @@ type Connection struct {
protocol.Connection protocol.Connection
} }
func (c Connection) String() string {
return fmt.Sprintf("%s-%s/%s", c.LocalAddr(), c.RemoteAddr(), c.Type)
}
type dialerFactory interface { type dialerFactory interface {
New(*config.Wrapper, *tls.Config) genericDialer New(*config.Wrapper, *tls.Config) genericDialer
Priority() int Priority() int

View File

@ -94,6 +94,7 @@ type Model struct {
fmut sync.RWMutex // protects the above fmut sync.RWMutex // protects the above
conn map[protocol.DeviceID]connections.Connection conn map[protocol.DeviceID]connections.Connection
closed map[protocol.DeviceID]chan struct{}
helloMessages map[protocol.DeviceID]protocol.HelloResult helloMessages map[protocol.DeviceID]protocol.HelloResult
devicePaused map[protocol.DeviceID]bool devicePaused map[protocol.DeviceID]bool
deviceDownloads map[protocol.DeviceID]*deviceDownloadState deviceDownloads map[protocol.DeviceID]*deviceDownloadState
@ -152,6 +153,7 @@ func NewModel(cfg *config.Wrapper, id protocol.DeviceID, deviceName, clientName,
folderRunnerTokens: make(map[string][]suture.ServiceToken), folderRunnerTokens: make(map[string][]suture.ServiceToken),
folderStatRefs: make(map[string]*stats.FolderStatisticsReference), folderStatRefs: make(map[string]*stats.FolderStatisticsReference),
conn: make(map[protocol.DeviceID]connections.Connection), conn: make(map[protocol.DeviceID]connections.Connection),
closed: make(map[protocol.DeviceID]chan struct{}),
helloMessages: make(map[protocol.DeviceID]protocol.HelloResult), helloMessages: make(map[protocol.DeviceID]protocol.HelloResult),
devicePaused: make(map[protocol.DeviceID]bool), devicePaused: make(map[protocol.DeviceID]bool),
deviceDownloads: make(map[protocol.DeviceID]*deviceDownloadState), deviceDownloads: make(map[protocol.DeviceID]*deviceDownloadState),
@ -912,25 +914,42 @@ func (m *Model) ClusterConfig(deviceID protocol.DeviceID, cm protocol.ClusterCon
} }
} }
// Close removes the peer from the model and closes the underlying connection if possible. // Closed is called when a connection has been closed
// Implements the protocol.Model interface. func (m *Model) Closed(conn protocol.Connection, err error) {
func (m *Model) Close(device protocol.DeviceID, err error) { device := conn.ID()
l.Infof("Connection to %s closed: %v", device, err)
events.Default.Log(events.DeviceDisconnected, map[string]string{
"id": device.String(),
"error": err.Error(),
})
m.pmut.Lock() m.pmut.Lock()
conn, ok := m.conn[device] conn, ok := m.conn[device]
if ok { if ok {
m.progressEmitter.temporaryIndexUnsubscribe(conn) m.progressEmitter.temporaryIndexUnsubscribe(conn)
closeRawConn(conn)
} }
delete(m.conn, device) delete(m.conn, device)
delete(m.helloMessages, device) delete(m.helloMessages, device)
delete(m.deviceDownloads, device) delete(m.deviceDownloads, device)
closed := m.closed[device]
delete(m.closed, device)
m.pmut.Unlock() m.pmut.Unlock()
l.Infof("Connection to %s closed: %v", device, err)
events.Default.Log(events.DeviceDisconnected, map[string]string{
"id": device.String(),
"error": err.Error(),
})
close(closed)
}
// close will close the underlying connection for a given device
func (m *Model) close(device protocol.DeviceID) {
m.pmut.Lock()
conn, ok := m.conn[device]
m.pmut.Unlock()
if !ok {
// There is no connection to close
return
}
closeRawConn(conn)
} }
// Request returns the specified data segment by reading it from local disk. // Request returns the specified data segment by reading it from local disk.
@ -1171,10 +1190,22 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR
deviceID := conn.ID() deviceID := conn.ID()
m.pmut.Lock() m.pmut.Lock()
if _, ok := m.conn[deviceID]; ok { if oldConn, ok := m.conn[deviceID]; ok {
panic("add existing device") l.Infoln("Replacing old connection", oldConn, "with", conn, "for", deviceID)
// There is an existing connection to this device that we are
// replacing. We must close the existing connection and wait for the
// close to complete before adding the new connection. We do the
// actual close without holding pmut as the connection will call
// back into Closed() for the cleanup.
closed := m.closed[deviceID]
m.pmut.Unlock()
closeRawConn(oldConn)
<-closed
m.pmut.Lock()
} }
m.conn[deviceID] = conn m.conn[deviceID] = conn
m.closed[deviceID] = make(chan struct{})
m.deviceDownloads[deviceID] = newDeviceDownloadState() m.deviceDownloads[deviceID] = newDeviceDownloadState()
m.helloMessages[deviceID] = hello m.helloMessages[deviceID] = hello
@ -1215,10 +1246,10 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR
func (m *Model) PauseDevice(device protocol.DeviceID) { func (m *Model) PauseDevice(device protocol.DeviceID) {
m.pmut.Lock() m.pmut.Lock()
m.devicePaused[device] = true m.devicePaused[device] = true
_, ok := m.conn[device] conn, ok := m.conn[device]
m.pmut.Unlock() m.pmut.Unlock()
if ok { if ok {
m.Close(device, errors.New("device paused")) closeRawConn(conn)
} }
events.Default.Log(events.DevicePaused, map[string]string{"device": device.String()}) events.Default.Log(events.DevicePaused, map[string]string{"device": device.String()})
} }

View File

@ -351,7 +351,7 @@ func TestDeviceRename(t *testing.T) {
t.Errorf("Device already has a name") t.Errorf("Device already has a name")
} }
m.Close(device1, protocol.ErrTimeout) m.Closed(conn, protocol.ErrTimeout)
hello.DeviceName = "tester" hello.DeviceName = "tester"
m.AddConnection(conn, hello) m.AddConnection(conn, hello)
@ -359,7 +359,7 @@ func TestDeviceRename(t *testing.T) {
t.Errorf("Device did not get a name") t.Errorf("Device did not get a name")
} }
m.Close(device1, protocol.ErrTimeout) m.Closed(conn, protocol.ErrTimeout)
hello.DeviceName = "tester2" hello.DeviceName = "tester2"
m.AddConnection(conn, hello) m.AddConnection(conn, hello)
@ -376,7 +376,7 @@ func TestDeviceRename(t *testing.T) {
t.Errorf("Device name not saved in config") t.Errorf("Device name not saved in config")
} }
m.Close(device1, protocol.ErrTimeout) m.Closed(conn, protocol.ErrTimeout)
opts := cfg.Options() opts := cfg.Options()
opts.OverwriteRemoteDevNames = true opts.OverwriteRemoteDevNames = true
@ -1527,7 +1527,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
m.StartFolder(fcfg.ID) m.StartFolder(fcfg.ID)
m.ServeBackground() m.ServeBackground()
m.AddConnection(connections.Connection{ conn1 := connections.Connection{
IntermediateConnection: connections.IntermediateConnection{ IntermediateConnection: connections.IntermediateConnection{
Conn: tls.Client(&fakeConn{}, nil), Conn: tls.Client(&fakeConn{}, nil),
Type: "foo", Type: "foo",
@ -1536,8 +1536,9 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
Connection: &FakeConnection{ Connection: &FakeConnection{
id: device1, id: device1,
}, },
}, protocol.HelloResult{}) }
m.AddConnection(connections.Connection{ m.AddConnection(conn1, protocol.HelloResult{})
conn2 := connections.Connection{
IntermediateConnection: connections.IntermediateConnection{ IntermediateConnection: connections.IntermediateConnection{
Conn: tls.Client(d2c, nil), Conn: tls.Client(d2c, nil),
Type: "foo", Type: "foo",
@ -1546,7 +1547,8 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
Connection: &FakeConnection{ Connection: &FakeConnection{
id: device2, id: device2,
}, },
}, protocol.HelloResult{}) }
m.AddConnection(conn2, protocol.HelloResult{})
m.ClusterConfig(device1, protocol.ClusterConfig{ m.ClusterConfig(device1, protocol.ClusterConfig{
Folders: []protocol.Folder{ Folders: []protocol.Folder{
@ -1629,7 +1631,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
t.Error("downloads missing early") t.Error("downloads missing early")
} }
m.Close(device2, fmt.Errorf("foo")) m.Closed(conn2, fmt.Errorf("foo"))
if _, ok := m.conn[device2]; ok { if _, ok := m.conn[device2]; ok {
t.Error("conn not missing") t.Error("conn not missing")

View File

@ -181,7 +181,7 @@ func (m *fakeModel) Request(deviceID DeviceID, folder string, name string, offse
func (m *fakeModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) { func (m *fakeModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) {
} }
func (m *fakeModel) Close(deviceID DeviceID, err error) { func (m *fakeModel) Closed(conn Connection, err error) {
} }
func (m *fakeModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) { func (m *fakeModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) {

View File

@ -39,7 +39,7 @@ func (t *TestModel) Request(deviceID DeviceID, folder, name string, offset int64
return nil return nil
} }
func (t *TestModel) Close(deviceID DeviceID, err error) { func (t *TestModel) Closed(conn Connection, err error) {
t.closedErr = err t.closedErr = err
close(t.closedCh) close(t.closedCh)
} }

View File

@ -67,7 +67,7 @@ type Model interface {
// A cluster configuration message was received // A cluster configuration message was received
ClusterConfig(deviceID DeviceID, config ClusterConfig) ClusterConfig(deviceID DeviceID, config ClusterConfig)
// The peer device closed the connection // The peer device closed the connection
Close(deviceID DeviceID, err error) Closed(conn Connection, err error)
// The peer device sent progress updates for the files it is currently downloading // The peer device sent progress updates for the files it is currently downloading
DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate)
} }
@ -729,7 +729,7 @@ func (c *rawConnection) close(err error) {
} }
c.awaitingMut.Unlock() c.awaitingMut.Unlock()
go c.receiver.Close(c.id, err) c.receiver.Closed(c, err)
}) })
} }