diff --git a/lib/connections/service.go b/lib/connections/service.go index e61a8c082..fa0ee7093 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -335,13 +335,7 @@ func (s *service) handle(ctx context.Context) error { isLAN := s.isLAN(c.RemoteAddr()) rd, wr := s.limiter.getLimiters(remoteID, c, isLAN) - var protoConn protocol.Connection - passwords := s.cfg.FolderPasswords(remoteID) - if len(passwords) > 0 { - protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, c, s.model, c, deviceCfg.Compression) - } else { - protoConn = protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression) - } + protoConn := protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression, s.cfg.FolderPasswords(remoteID)) l.Infof("Established secure connection to %s at %s", remoteID, c) diff --git a/lib/model/fakeconns_test.go b/lib/model/fakeconns_test.go index c50137865..b1a8e91ef 100644 --- a/lib/model/fakeconns_test.go +++ b/lib/model/fakeconns_test.go @@ -33,7 +33,7 @@ func newFakeConnection(id protocol.DeviceID, model Model) *fakeConnection { }) f.IDReturns(id) f.CloseCalls(func(err error) { - model.Closed(f, err) + model.Closed(id, err) f.ClosedReturns(true) }) return f diff --git a/lib/model/mocks/model.go b/lib/model/mocks/model.go index 9e4048b19..8491d87d1 100644 --- a/lib/model/mocks/model.go +++ b/lib/model/mocks/model.go @@ -43,10 +43,10 @@ type Model struct { arg1 string arg2 string } - ClosedStub func(protocol.Connection, error) + ClosedStub func(protocol.DeviceID, error) closedMutex sync.RWMutex closedArgsForCall []struct { - arg1 protocol.Connection + arg1 protocol.DeviceID arg2 error } ClusterConfigStub func(protocol.DeviceID, protocol.ClusterConfig) error @@ -684,10 +684,10 @@ func (fake *Model) BringToFrontArgsForCall(i int) (string, string) { return argsForCall.arg1, argsForCall.arg2 } -func (fake *Model) Closed(arg1 protocol.Connection, arg2 error) { +func (fake *Model) Closed(arg1 protocol.DeviceID, arg2 error) { fake.closedMutex.Lock() fake.closedArgsForCall = append(fake.closedArgsForCall, struct { - arg1 protocol.Connection + arg1 protocol.DeviceID arg2 error }{arg1, arg2}) stub := fake.ClosedStub @@ -704,13 +704,13 @@ func (fake *Model) ClosedCallCount() int { return len(fake.closedArgsForCall) } -func (fake *Model) ClosedCalls(stub func(protocol.Connection, error)) { +func (fake *Model) ClosedCalls(stub func(protocol.DeviceID, error)) { fake.closedMutex.Lock() defer fake.closedMutex.Unlock() fake.ClosedStub = stub } -func (fake *Model) ClosedArgsForCall(i int) (protocol.Connection, error) { +func (fake *Model) ClosedArgsForCall(i int) (protocol.DeviceID, error) { fake.closedMutex.RLock() defer fake.closedMutex.RUnlock() argsForCall := fake.closedArgsForCall[i] diff --git a/lib/model/model.go b/lib/model/model.go index 8f65ae373..c0024a02b 100644 --- a/lib/model/model.go +++ b/lib/model/model.go @@ -293,7 +293,7 @@ func (m *model) initFolders(cfg config.Configuration) error { ignoredDevices := observedDeviceSet(m.cfg.IgnoredDevices()) m.cleanPending(cfg.DeviceMap(), cfg.FolderMap(), ignoredDevices, nil) - m.resendClusterConfig(clusterConfigDevices.AsSlice()) + m.sendClusterConfig(clusterConfigDevices.AsSlice()) return nil } @@ -1510,7 +1510,7 @@ func (m *model) ccCheckEncryption(fcfg config.FolderConfiguration, folderDevice m.fmut.Unlock() // We can only announce ourselfs once we have the token, // thus we need to resend CCs now that we have it. - m.resendClusterConfig(fcfg.DeviceIDs()) + m.sendClusterConfig(fcfg.DeviceIDs()) return nil } } @@ -1520,7 +1520,7 @@ func (m *model) ccCheckEncryption(fcfg config.FolderConfiguration, folderDevice return nil } -func (m *model) resendClusterConfig(ids []protocol.DeviceID) { +func (m *model) sendClusterConfig(ids []protocol.DeviceID) { if len(ids) == 0 { return } @@ -1534,7 +1534,8 @@ func (m *model) resendClusterConfig(ids []protocol.DeviceID) { m.pmut.RUnlock() // Generating cluster-configs acquires fmut -> must happen outside of pmut. for _, conn := range ccConns { - cm := m.generateClusterConfig(conn.ID()) + cm, passwords := m.generateClusterConfig(conn.ID()) + conn.SetFolderPasswords(passwords) go conn.ClusterConfig(cm) } } @@ -1728,9 +1729,7 @@ func (m *model) introduceDevice(device protocol.Device, introducerCfg config.Dev } // Closed is called when a connection has been closed -func (m *model) Closed(conn protocol.Connection, err error) { - device := conn.ID() - +func (m *model) Closed(device protocol.DeviceID, err error) { m.pmut.Lock() conn, ok := m.conn[device] if !ok { @@ -2247,7 +2246,8 @@ func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) { m.pmut.Unlock() // Acquires fmut, so has to be done outside of pmut. - cm := m.generateClusterConfig(deviceID) + cm, passwords := m.generateClusterConfig(deviceID) + conn.SetFolderPasswords(passwords) conn.ClusterConfig(cm) if (device.Name == "" || m.cfg.Options().OverwriteRemoteDevNames) && hello.DeviceName != "" { @@ -2407,15 +2407,17 @@ func (m *model) numHashers(folder string) int { return 1 } -// generateClusterConfig returns a ClusterConfigMessage that is correct for -// the given peer device -func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.ClusterConfig { +// generateClusterConfig returns a ClusterConfigMessage that is correct and the +// set of folder passwords for the given peer device +func (m *model) generateClusterConfig(device protocol.DeviceID) (protocol.ClusterConfig, map[string]string) { var message protocol.ClusterConfig m.fmut.RLock() defer m.fmut.RUnlock() - for _, folderCfg := range m.cfg.FolderList() { + folders := m.cfg.FolderList() + passwords := make(map[string]string, len(folders)) + for _, folderCfg := range folders { if !folderCfg.SharedWith(device) { continue } @@ -2448,8 +2450,8 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster // another cluster config once the folder is started. protocolFolder.Paused = folderCfg.Paused || fs == nil - for _, device := range folderCfg.Devices { - deviceCfg, _ := m.cfg.Device(device.DeviceID) + for _, folderDevice := range folderCfg.Devices { + deviceCfg, _ := m.cfg.Device(folderDevice.DeviceID) protocolDevice := protocol.Device{ ID: deviceCfg.DeviceID, @@ -2462,8 +2464,11 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster if deviceCfg.DeviceID == m.id && hasEncryptionToken { protocolDevice.EncryptionPasswordToken = encryptionToken - } else if device.EncryptionPassword != "" { - protocolDevice.EncryptionPasswordToken = protocol.PasswordToken(folderCfg.ID, device.EncryptionPassword) + } else if folderDevice.EncryptionPassword != "" { + protocolDevice.EncryptionPasswordToken = protocol.PasswordToken(folderCfg.ID, folderDevice.EncryptionPassword) + if folderDevice.DeviceID == device { + passwords[folderCfg.ID] = folderDevice.EncryptionPassword + } } if fs != nil { @@ -2482,7 +2487,7 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster message.Folders = append(message.Folders, protocolFolder) } - return message + return message, passwords } func (m *model) State(folder string) (string, time.Time, error) { @@ -2891,7 +2896,7 @@ func (m *model) CommitConfiguration(from, to config.Configuration) bool { } m.pmut.RUnlock() // Generating cluster-configs acquires fmut -> must happen outside of pmut. - m.resendClusterConfig(clusterConfigDevices.AsSlice()) + m.sendClusterConfig(clusterConfigDevices.AsSlice()) ignoredDevices := observedDeviceSet(to.IgnoredDevices) m.cleanPending(toDevices, toFolders, ignoredDevices, removedFolders) diff --git a/lib/model/model_test.go b/lib/model/model_test.go index 736dfe628..096d2aa9b 100644 --- a/lib/model/model_test.go +++ b/lib/model/model_test.go @@ -341,7 +341,7 @@ func TestDeviceRename(t *testing.T) { t.Errorf("Device already has a name") } - m.Closed(conn, protocol.ErrTimeout) + m.Closed(conn.ID(), protocol.ErrTimeout) hello.DeviceName = "tester" m.AddConnection(conn, hello) @@ -349,7 +349,7 @@ func TestDeviceRename(t *testing.T) { t.Errorf("Device did not get a name") } - m.Closed(conn, protocol.ErrTimeout) + m.Closed(conn.ID(), protocol.ErrTimeout) hello.DeviceName = "tester2" m.AddConnection(conn, hello) @@ -367,7 +367,7 @@ func TestDeviceRename(t *testing.T) { t.Errorf("Device name not saved in config") } - m.Closed(conn, protocol.ErrTimeout) + m.Closed(conn.ID(), protocol.ErrTimeout) waiter, err := cfg.Modify(func(cfg *config.Configuration) { cfg.Options.OverwriteRemoteDevNames = true @@ -428,7 +428,7 @@ func TestClusterConfig(t *testing.T) { m.ServeBackground() defer cleanupModel(m) - cm := m.generateClusterConfig(device2) + cm, _ := m.generateClusterConfig(device2) if l := len(cm.Folders); l != 2 { t.Fatalf("Incorrect number of folders %d != 2", l) @@ -853,7 +853,7 @@ func TestIssue4897(t *testing.T) { defer cleanupModel(m) cancel() - cm := m.generateClusterConfig(device1) + cm, _ := m.generateClusterConfig(device1) if l := len(cm.Folders); l != 1 { t.Errorf("Cluster config contains %v folders, expected 1", l) } @@ -873,7 +873,7 @@ func TestIssue5063(t *testing.T) { for _, c := range m.conn { conn := c.(*fakeConnection) conn.CloseCalls(func(_ error) {}) - defer m.Closed(c, errStopped) // to unblock deferred m.Stop() + defer m.Closed(c.ID(), errStopped) // to unblock deferred m.Stop() } m.pmut.Unlock() @@ -2428,8 +2428,8 @@ func TestNoRequestsFromPausedDevices(t *testing.T) { t.Errorf("should have two available") } - m.Closed(newFakeConnection(device1, m), errDeviceUnknown) - m.Closed(newFakeConnection(device2, m), errDeviceUnknown) + m.Closed(device1, errDeviceUnknown) + m.Closed(device2, errDeviceUnknown) avail = m.testAvailability("default", file, file.Blocks[0]) if len(avail) != 0 { @@ -3171,7 +3171,7 @@ func TestConnCloseOnRestart(t *testing.T) { br := &testutils.BlockingRW{} nw := &testutils.NoopRW{} - m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, new(protocolmocks.ConnectionInfo), protocol.CompressionNever), protocol.Hello{}) + m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, new(protocolmocks.ConnectionInfo), protocol.CompressionNever, nil), protocol.Hello{}) m.pmut.RLock() if len(m.closed) != 1 { t.Fatalf("Expected just one conn (len(m.conn) == %v)", len(m.conn)) @@ -4142,7 +4142,7 @@ func TestCCFolderNotRunning(t *testing.T) { defer cleanupModelAndRemoveDir(m, tfs.URI()) // A connection can happen before all the folders are started. - cc := m.generateClusterConfig(device1) + cc, _ := m.generateClusterConfig(device1) if l := len(cc.Folders); l != 1 { t.Fatalf("Expected 1 folder in CC, got %v", l) } diff --git a/lib/protocol/benchmark_test.go b/lib/protocol/benchmark_test.go index 5b005f16b..1b3bd304e 100644 --- a/lib/protocol/benchmark_test.go +++ b/lib/protocol/benchmark_test.go @@ -60,9 +60,9 @@ func benchmarkRequestsTLS(b *testing.B, conn0, conn1 net.Conn) { func benchmarkRequestsConnPair(b *testing.B, conn0, conn1 net.Conn) { // Start up Connections on them - c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata) + c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata, nil) c0.Start() - c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata) + c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata, nil) c1.Start() // Satisfy the assertions in the protocol by sending an initial cluster config @@ -188,7 +188,7 @@ func (m *fakeModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) error return nil } -func (m *fakeModel) Closed(conn Connection, err error) { +func (m *fakeModel) Closed(DeviceID, error) { } func (m *fakeModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error { diff --git a/lib/protocol/common_test.go b/lib/protocol/common_test.go index 29a5b1d29..6e5c7f02c 100644 --- a/lib/protocol/common_test.go +++ b/lib/protocol/common_test.go @@ -49,7 +49,7 @@ func (t *TestModel) Request(deviceID DeviceID, folder, name string, blockNo, siz return &fakeRequestResponse{buf}, nil } -func (t *TestModel) Closed(conn Connection, err error) { +func (t *TestModel) Closed(_ DeviceID, err error) { t.closedErr = err close(t.closedCh) } diff --git a/lib/protocol/encryption.go b/lib/protocol/encryption.go index 0ccac6f42..f5cdebb1f 100644 --- a/lib/protocol/encryption.go +++ b/lib/protocol/encryption.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "strings" + "sync" "time" "github.com/gogo/protobuf/proto" @@ -41,11 +42,11 @@ const ( // must decrypt those and answer requests by encrypting the data. type encryptedModel struct { model Model - folderKeys map[string]*[keySize]byte // folder ID -> key + folderKeys *folderKeyRegistry } func (e encryptedModel) Index(deviceID DeviceID, folder string, files []FileInfo) error { - if folderKey, ok := e.folderKeys[folder]; ok { + if folderKey, ok := e.folderKeys.get(folder); ok { // incoming index data to be decrypted if err := decryptFileInfos(files, folderKey); err != nil { return err @@ -55,7 +56,7 @@ func (e encryptedModel) Index(deviceID DeviceID, folder string, files []FileInfo } func (e encryptedModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) error { - if folderKey, ok := e.folderKeys[folder]; ok { + if folderKey, ok := e.folderKeys.get(folder); ok { // incoming index data to be decrypted if err := decryptFileInfos(files, folderKey); err != nil { return err @@ -65,7 +66,7 @@ func (e encryptedModel) IndexUpdate(deviceID DeviceID, folder string, files []Fi } func (e encryptedModel) Request(deviceID DeviceID, folder, name string, blockNo, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) { - folderKey, ok := e.folderKeys[folder] + folderKey, ok := e.folderKeys.get(folder) if !ok { return e.model.Request(deviceID, folder, name, blockNo, size, offset, hash, weakHash, fromTemporary) } @@ -123,7 +124,7 @@ func (e encryptedModel) Request(deviceID DeviceID, folder, name string, blockNo, } func (e encryptedModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error { - if _, ok := e.folderKeys[folder]; !ok { + if _, ok := e.folderKeys.get(folder); !ok { return e.model.DownloadProgress(deviceID, folder, updates) } @@ -135,42 +136,46 @@ func (e encryptedModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) e return e.model.ClusterConfig(deviceID, config) } -func (e encryptedModel) Closed(conn Connection, err error) { - e.model.Closed(conn, err) +func (e encryptedModel) Closed(device DeviceID, err error) { + e.model.Closed(device, err) } // The encryptedConnection sits between the model and the encrypted device. It // encrypts outgoing metadata and decrypts incoming responses. type encryptedConnection struct { ConnectionInfo - conn Connection - folderKeys map[string]*[keySize]byte // folder ID -> key + conn *rawConnection + folderKeys *folderKeyRegistry } func (e encryptedConnection) Start() { e.conn.Start() } +func (e encryptedConnection) SetFolderPasswords(passwords map[string]string) { + e.folderKeys.setPasswords(passwords) +} + func (e encryptedConnection) ID() DeviceID { return e.conn.ID() } func (e encryptedConnection) Index(ctx context.Context, folder string, files []FileInfo) error { - if folderKey, ok := e.folderKeys[folder]; ok { + if folderKey, ok := e.folderKeys.get(folder); ok { encryptFileInfos(files, folderKey) } return e.conn.Index(ctx, folder, files) } func (e encryptedConnection) IndexUpdate(ctx context.Context, folder string, files []FileInfo) error { - if folderKey, ok := e.folderKeys[folder]; ok { + if folderKey, ok := e.folderKeys.get(folder); ok { encryptFileInfos(files, folderKey) } return e.conn.IndexUpdate(ctx, folder, files) } func (e encryptedConnection) Request(ctx context.Context, folder string, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error) { - folderKey, ok := e.folderKeys[folder] + folderKey, ok := e.folderKeys.get(folder) if !ok { return e.conn.Request(ctx, folder, name, blockNo, offset, size, hash, weakHash, fromTemporary) } @@ -205,7 +210,7 @@ func (e encryptedConnection) Request(ctx context.Context, folder string, name st } func (e encryptedConnection) DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate) { - if _, ok := e.folderKeys[folder]; !ok { + if _, ok := e.folderKeys.get(folder); !ok { e.conn.DownloadProgress(ctx, folder, updates) } @@ -590,3 +595,27 @@ func isEncryptedParentFromComponents(pathComponents []string) bool { } return true } + +type folderKeyRegistry struct { + keys map[string]*[keySize]byte // folder ID -> key + mut sync.RWMutex +} + +func newFolderKeyRegistry(passwords map[string]string) *folderKeyRegistry { + return &folderKeyRegistry{ + keys: keysFromPasswords(passwords), + } +} + +func (r *folderKeyRegistry) get(folder string) (*[keySize]byte, bool) { + r.mut.RLock() + key, ok := r.keys[folder] + r.mut.RUnlock() + return key, ok +} + +func (r *folderKeyRegistry) setPasswords(passwords map[string]string) { + r.mut.Lock() + r.keys = keysFromPasswords(passwords) + r.mut.Unlock() +} diff --git a/lib/protocol/mocks/connection.go b/lib/protocol/mocks/connection.go index c2c8cd381..a07ce4c2f 100644 --- a/lib/protocol/mocks/connection.go +++ b/lib/protocol/mocks/connection.go @@ -135,6 +135,11 @@ type Connection struct { result1 []byte result2 error } + SetFolderPasswordsStub func(map[string]string) + setFolderPasswordsMutex sync.RWMutex + setFolderPasswordsArgsForCall []struct { + arg1 map[string]string + } StartStub func() startMutex sync.RWMutex startArgsForCall []struct { @@ -817,6 +822,38 @@ func (fake *Connection) RequestReturnsOnCall(i int, result1 []byte, result2 erro }{result1, result2} } +func (fake *Connection) SetFolderPasswords(arg1 map[string]string) { + fake.setFolderPasswordsMutex.Lock() + fake.setFolderPasswordsArgsForCall = append(fake.setFolderPasswordsArgsForCall, struct { + arg1 map[string]string + }{arg1}) + stub := fake.SetFolderPasswordsStub + fake.recordInvocation("SetFolderPasswords", []interface{}{arg1}) + fake.setFolderPasswordsMutex.Unlock() + if stub != nil { + fake.SetFolderPasswordsStub(arg1) + } +} + +func (fake *Connection) SetFolderPasswordsCallCount() int { + fake.setFolderPasswordsMutex.RLock() + defer fake.setFolderPasswordsMutex.RUnlock() + return len(fake.setFolderPasswordsArgsForCall) +} + +func (fake *Connection) SetFolderPasswordsCalls(stub func(map[string]string)) { + fake.setFolderPasswordsMutex.Lock() + defer fake.setFolderPasswordsMutex.Unlock() + fake.SetFolderPasswordsStub = stub +} + +func (fake *Connection) SetFolderPasswordsArgsForCall(i int) map[string]string { + fake.setFolderPasswordsMutex.RLock() + defer fake.setFolderPasswordsMutex.RUnlock() + argsForCall := fake.setFolderPasswordsArgsForCall[i] + return argsForCall.arg1 +} + func (fake *Connection) Start() { fake.startMutex.Lock() fake.startArgsForCall = append(fake.startArgsForCall, struct { @@ -1080,6 +1117,8 @@ func (fake *Connection) Invocations() map[string][][]interface{} { defer fake.remoteAddrMutex.RUnlock() fake.requestMutex.RLock() defer fake.requestMutex.RUnlock() + fake.setFolderPasswordsMutex.RLock() + defer fake.setFolderPasswordsMutex.RUnlock() fake.startMutex.RLock() defer fake.startMutex.RUnlock() fake.statisticsMutex.RLock() diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 284df12d1..188822d37 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -126,8 +126,8 @@ type Model interface { Request(deviceID DeviceID, folder, name string, blockNo, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) // A cluster configuration message was received ClusterConfig(deviceID DeviceID, config ClusterConfig) error - // The peer device closed the connection - Closed(conn Connection, err error) + // The peer device closed the connection or an error occurred + Closed(device DeviceID, err error) // The peer device sent progress updates for the files it is currently downloading DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error } @@ -140,6 +140,7 @@ type RequestResponse interface { type Connection interface { Start() + SetFolderPasswords(passwords map[string]string) Close(err error) ID() DeviceID Index(ctx context.Context, folder string, files []FileInfo) error @@ -225,24 +226,16 @@ const ( // Should not be modified in production code, just for testing. var CloseTimeout = 10 * time.Second -func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection { - receiver = nativeModel{receiver} - rc := newRawConnection(deviceID, reader, writer, closer, receiver, connInfo, compress) - return wireFormatConnection{rc} -} - -func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection { - keys := keysFromPasswords(passwords) - +func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression, passwords map[string]string) Connection { // Encryption / decryption is first (outermost) before conversion to // native path formats. nm := nativeModel{receiver} - em := encryptedModel{model: nm, folderKeys: keys} + em := &encryptedModel{model: nm, folderKeys: newFolderKeyRegistry(passwords)} // We do the wire format conversion first (outermost) so that the // metadata is in wire format when it reaches the encryption step. rc := newRawConnection(deviceID, reader, writer, closer, em, connInfo, compress) - ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: keys} + ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: em.folderKeys} wc := wireFormatConnection{ec} return wc @@ -748,6 +741,8 @@ func (c *rawConnection) writerLoop() { } func (c *rawConnection) writeMessage(msg message) error { + msgContext, _ := messageContext(msg) + l.Debugf("Writing %v", msgContext) if c.shouldCompressMessage(msg) { return c.writeCompressedMessage(msg) } @@ -955,7 +950,7 @@ func (c *rawConnection) internalClose(err error) { <-c.dispatcherLoopStopped - c.receiver.Closed(c, err) + c.receiver.Closed(c.ID(), err) }) } diff --git a/lib/protocol/protocol_test.go b/lib/protocol/protocol_test.go index cfff22d77..a82cbb961 100644 --- a/lib/protocol/protocol_test.go +++ b/lib/protocol/protocol_test.go @@ -31,10 +31,10 @@ func TestPing(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() - c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways, nil)) c0.Start() defer closeAndWait(c0, ar, bw) - c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + c1 := getRawConnection(NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways, nil)) c1.Start() defer closeAndWait(c1, ar, bw) c0.ClusterConfig(ClusterConfig{}) @@ -57,10 +57,10 @@ func TestClose(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() - c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionAlways, nil)) c0.Start() defer closeAndWait(c0, ar, bw) - c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionAlways) + c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionAlways, nil) c1.Start() defer closeAndWait(c1, ar, bw) c0.ClusterConfig(ClusterConfig{}) @@ -102,7 +102,7 @@ func TestCloseOnBlockingSend(t *testing.T) { m := newTestModel() rw := testutils.NewBlockingRW() - c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil)) c.Start() defer closeAndWait(c, rw) @@ -153,10 +153,10 @@ func TestCloseRace(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() - c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionNever).(wireFormatConnection).Connection.(*rawConnection) + c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionNever, nil)) c0.Start() defer closeAndWait(c0, ar, bw) - c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionNever) + c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionNever, nil) c1.Start() defer closeAndWait(c1, ar, bw) c0.ClusterConfig(ClusterConfig{}) @@ -193,7 +193,7 @@ func TestClusterConfigFirst(t *testing.T) { m := newTestModel() rw := testutils.NewBlockingRW() - c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + c := getRawConnection(NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil)) c.Start() defer closeAndWait(c, rw) @@ -245,7 +245,7 @@ func TestCloseTimeout(t *testing.T) { m := newTestModel() rw := testutils.NewBlockingRW() - c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil)) c.Start() defer closeAndWait(c, rw) @@ -865,7 +865,7 @@ func TestClusterConfigAfterClose(t *testing.T) { m := newTestModel() rw := testutils.NewBlockingRW() - c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil)) c.Start() defer closeAndWait(c, rw) @@ -889,7 +889,7 @@ func TestDispatcherToCloseDeadlock(t *testing.T) { // the model callbacks (ClusterConfig). m := newTestModel() rw := testutils.NewBlockingRW() - c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + c := getRawConnection(NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil)) m.ccFn = func(devID DeviceID, cc ClusterConfig) { c.Close(errManual) } @@ -962,17 +962,28 @@ func TestIndexIDString(t *testing.T) { } } -func closeAndWait(c Connection, closers ...io.Closer) { +func closeAndWait(c interface{}, closers ...io.Closer) { for _, closer := range closers { closer.Close() } var raw *rawConnection switch i := c.(type) { - case wireFormatConnection: - raw = i.Connection.(*rawConnection) case *rawConnection: raw = i + default: + raw = getRawConnection(c.(Connection)) } raw.internalClose(ErrClosed) raw.loopWG.Wait() } + +func getRawConnection(c Connection) *rawConnection { + var raw *rawConnection + switch i := c.(type) { + case wireFormatConnection: + raw = i.Connection.(encryptedConnection).conn + case encryptedConnection: + raw = i.conn + } + return raw +}