lib: Close underlying conn in protocol (fixes #7165) (#7212)

This commit is contained in:
Simon Frei 2020-12-21 11:40:51 +01:00 committed by GitHub
parent 4a787986cd
commit c845e245a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 106 additions and 124 deletions

View File

@ -11,7 +11,6 @@ import (
"net"
"time"
"github.com/syncthing/syncthing/lib/connections"
"github.com/syncthing/syncthing/lib/db"
"github.com/syncthing/syncthing/lib/model"
"github.com/syncthing/syncthing/lib/protocol"
@ -114,7 +113,7 @@ func (m *mockedModel) ScanFolderSubdirs(folder string, subs []string) error {
func (m *mockedModel) BringToFront(folder, file string) {}
func (m *mockedModel) Connection(deviceID protocol.DeviceID) (connections.Connection, bool) {
func (m *mockedModel) Connection(deviceID protocol.DeviceID) (protocol.Connection, bool) {
return nil, false
}
@ -165,7 +164,7 @@ func (m *mockedModel) DownloadProgress(deviceID protocol.DeviceID, folder string
return nil
}
func (m *mockedModel) AddConnection(conn connections.Connection, hello protocol.Hello) {}
func (m *mockedModel) AddConnection(conn protocol.Connection, hello protocol.Hello) {}
func (m *mockedModel) OnHello(protocol.DeviceID, net.Addr, protocol.Hello) error {
return nil

View File

@ -329,15 +329,14 @@ func (s *service) handle(ctx context.Context) error {
var protoConn protocol.Connection
passwords := s.cfg.FolderPasswords(remoteID)
if len(passwords) > 0 {
protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, s.model, c.String(), deviceCfg.Compression)
protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
} else {
protoConn = protocol.NewConnection(remoteID, rd, wr, s.model, c.String(), deviceCfg.Compression)
protoConn = protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
}
modelConn := completeConn{c, protoConn}
l.Infof("Established secure connection to %s at %s", remoteID, c)
s.model.AddConnection(modelConn, hello)
s.model.AddConnection(protoConn, hello)
continue
}
}

View File

@ -22,31 +22,6 @@ import (
"github.com/thejerf/suture/v4"
)
// Connection is what we expose to the outside. It is a protocol.Connection
// that can be closed and has some metadata.
type Connection interface {
protocol.Connection
Type() string
Transport() string
RemoteAddr() net.Addr
Priority() int
String() string
Crypto() string
}
// completeConn is the aggregation of an internalConn and the
// protocol.Connection running on top of it. It implements the Connection
// interface.
type completeConn struct {
internalConn
protocol.Connection
}
func (c completeConn) Close(err error) {
c.Connection.Close(err)
c.internalConn.Close()
}
type tlsConn interface {
io.ReadWriteCloser
ConnectionState() tls.ConnectionState
@ -107,12 +82,12 @@ func (t connType) Transport() string {
}
}
func (c internalConn) Close() {
func (c internalConn) Close() error {
// *tls.Conn.Close() does more than it says on the tin. Specifically, it
// sends a TLS alert message, which might block forever if the
// connection is dead and we don't have a deadline set.
_ = c.SetWriteDeadline(time.Now().Add(250 * time.Millisecond))
_ = c.tlsConn.Close()
return c.tlsConn.Close()
}
func (c internalConn) Type() string {
@ -203,8 +178,8 @@ type genericListener interface {
type Model interface {
protocol.Model
AddConnection(conn Connection, hello protocol.Hello)
Connection(remoteID protocol.DeviceID) (Connection, bool)
AddConnection(conn protocol.Connection, hello protocol.Hello)
Connection(remoteID protocol.DeviceID) (protocol.Connection, bool)
OnHello(protocol.DeviceID, net.Addr, protocol.Hello) error
GetHello(protocol.DeviceID) protocol.HelloIntf
}

View File

@ -9,13 +9,12 @@ package model
import (
"bytes"
"context"
"net"
"sync"
"time"
"github.com/syncthing/syncthing/lib/connections"
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/scanner"
"github.com/syncthing/syncthing/lib/testutils"
)
type downloadProgressMessage struct {
@ -24,7 +23,7 @@ type downloadProgressMessage struct {
}
type fakeConnection struct {
fakeUnderlyingConn
testutils.FakeConnectionInfo
id protocol.DeviceID
downloadProgressMessages []downloadProgressMessage
closed bool
@ -219,50 +218,3 @@ func addFakeConn(m *testModel, dev protocol.DeviceID) *fakeConnection {
return fc
}
type fakeProtoConn struct {
protocol.Connection
fakeUnderlyingConn
}
func newFakeProtoConn(protoConn protocol.Connection) connections.Connection {
return &fakeProtoConn{Connection: protoConn}
}
// fakeUnderlyingConn implements the methods of connections.Connection that are
// not implemented by protocol.Connection
type fakeUnderlyingConn struct{}
func (f *fakeUnderlyingConn) RemoteAddr() net.Addr {
return &fakeAddr{}
}
func (f *fakeUnderlyingConn) Type() string {
return "fake"
}
func (f *fakeUnderlyingConn) Crypto() string {
return "fake"
}
func (f *fakeUnderlyingConn) Transport() string {
return "fake"
}
func (f *fakeUnderlyingConn) Priority() int {
return 9000
}
func (f *fakeUnderlyingConn) String() string {
return ""
}
type fakeAddr struct{}
func (fakeAddr) Network() string {
return "network"
}
func (fakeAddr) String() string {
return "address"
}

View File

@ -150,7 +150,7 @@ type model struct {
// fields protected by pmut
pmut sync.RWMutex
conn map[protocol.DeviceID]connections.Connection
conn map[protocol.DeviceID]protocol.Connection
connRequestLimiters map[protocol.DeviceID]*byteSemaphore
closed map[protocol.DeviceID]chan struct{}
helloMessages map[protocol.DeviceID]protocol.Hello
@ -232,7 +232,7 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio
// fields protected by pmut
pmut: sync.NewRWMutex(),
conn: make(map[protocol.DeviceID]connections.Connection),
conn: make(map[protocol.DeviceID]protocol.Connection),
connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore),
closed: make(map[protocol.DeviceID]chan struct{}),
helloMessages: make(map[protocol.DeviceID]protocol.Hello),
@ -1660,7 +1660,7 @@ func (m *model) Closed(conn protocol.Connection, err error) {
m.progressEmitter.temporaryIndexUnsubscribe(conn)
l.Infof("Connection to %s at %s closed: %v", device, conn.Name(), err)
l.Infof("Connection to %s at %s closed: %v", device, conn, err)
m.evLogger.Log(events.DeviceDisconnected, map[string]string{
"id": device.String(),
"error": err.Error(),
@ -1912,7 +1912,7 @@ func (m *model) CurrentGlobalFile(folder string, file string) (protocol.FileInfo
}
// Connection returns the current connection for device, and a boolean whether a connection was found.
func (m *model) Connection(deviceID protocol.DeviceID) (connections.Connection, bool) {
func (m *model) Connection(deviceID protocol.DeviceID) (protocol.Connection, bool) {
m.pmut.RLock()
cn, ok := m.conn[deviceID]
m.pmut.RUnlock()
@ -2039,7 +2039,7 @@ func (m *model) GetHello(id protocol.DeviceID) protocol.HelloIntf {
// AddConnection adds a new peer connection to the model. An initial index will
// be sent to the connected peer, thereafter index updates whenever the local
// folder changes.
func (m *model) AddConnection(conn connections.Connection, hello protocol.Hello) {
func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) {
deviceID := conn.ID()
device, ok := m.cfg.Device(deviceID)
if !ok {

View File

@ -3297,7 +3297,7 @@ func TestConnCloseOnRestart(t *testing.T) {
br := &testutils.BlockingRW{}
nw := &testutils.NoopRW{}
m.AddConnection(newFakeProtoConn(protocol.NewConnection(device1, br, nw, m, "testConn", protocol.CompressionNever)), protocol.Hello{})
m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"fc"}, protocol.CompressionNever), protocol.Hello{})
m.pmut.RLock()
if len(m.closed) != 1 {
t.Fatalf("Expected just one conn (len(m.conn) == %v)", len(m.conn))

View File

@ -10,6 +10,7 @@ import (
"testing"
"github.com/syncthing/syncthing/lib/dialer"
"github.com/syncthing/syncthing/lib/testutils"
)
func BenchmarkRequestsRawTCP(b *testing.B) {
@ -59,9 +60,9 @@ func benchmarkRequestsTLS(b *testing.B, conn0, conn1 net.Conn) {
func benchmarkRequestsConnPair(b *testing.B, conn0, conn1 net.Conn) {
// Start up Connections on them
c0 := NewConnection(LocalDeviceID, conn0, conn0, new(fakeModel), "c0", CompressionMetadata)
c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), &testutils.FakeConnectionInfo{"c0"}, CompressionMetadata)
c0.Start()
c1 := NewConnection(LocalDeviceID, conn1, conn1, new(fakeModel), "c1", CompressionMetadata)
c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), &testutils.FakeConnectionInfo{"c1"}, CompressionMetadata)
c1.Start()
// Satisfy the assertions in the protocol by sending an initial cluster config

View File

@ -128,6 +128,7 @@ func (e encryptedModel) Closed(conn Connection, err error) {
// The encryptedConnection sits between the model and the encrypted device. It
// encrypts outgoing metadata and decrypts incoming responses.
type encryptedConnection struct {
ConnectionInfo
conn Connection
folderKeys map[string]*[keySize]byte // folder ID -> key
}
@ -140,10 +141,6 @@ func (e encryptedConnection) ID() DeviceID {
return e.conn.ID()
}
func (e encryptedConnection) Name() string {
return e.conn.Name()
}
func (e encryptedConnection) Index(ctx context.Context, folder string, files []FileInfo) error {
if folderKey, ok := e.folderKeys[folder]; ok {
encryptFileInfos(files, folderKey)

View File

@ -8,6 +8,7 @@ import (
"encoding/binary"
"fmt"
"io"
"net"
"path"
"strings"
"sync"
@ -134,7 +135,6 @@ type Connection interface {
Start()
Close(err error)
ID() DeviceID
Name() string
Index(ctx context.Context, folder string, files []FileInfo) error
IndexUpdate(ctx context.Context, folder string, files []FileInfo) error
Request(ctx context.Context, folder string, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error)
@ -142,16 +142,28 @@ type Connection interface {
DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate)
Statistics() Statistics
Closed() bool
ConnectionInfo
}
type ConnectionInfo interface {
Type() string
Transport() string
RemoteAddr() net.Addr
Priority() int
String() string
Crypto() string
}
type rawConnection struct {
ConnectionInfo
id DeviceID
name string
receiver Model
startTime time.Time
cr *countingReader
cw *countingWriter
cr *countingReader
cw *countingWriter
closer io.Closer // Closing the underlying connection and thus cr and cw
awaiting map[int]chan asyncResult
awaitingMut sync.Mutex
@ -205,13 +217,13 @@ const (
// Should not be modified in production code, just for testing.
var CloseTimeout = 10 * time.Second
func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection {
func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
receiver = nativeModel{receiver}
rc := newRawConnection(deviceID, reader, writer, receiver, name, compress)
rc := newRawConnection(deviceID, reader, writer, closer, receiver, connInfo, compress)
return wireFormatConnection{rc}
}
func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection {
func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
keys := keysFromPasswords(passwords)
// Encryption / decryption is first (outermost) before conversion to
@ -221,23 +233,24 @@ func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, read
// We do the wire format conversion first (outermost) so that the
// metadata is in wire format when it reaches the encryption step.
rc := newRawConnection(deviceID, reader, writer, em, name, compress)
ec := encryptedConnection{conn: rc, folderKeys: keys}
rc := newRawConnection(deviceID, reader, writer, closer, em, connInfo, compress)
ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: keys}
wc := wireFormatConnection{ec}
return wc
}
func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) *rawConnection {
func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) *rawConnection {
cr := &countingReader{Reader: reader}
cw := &countingWriter{Writer: writer}
return &rawConnection{
ConnectionInfo: connInfo,
id: deviceID,
name: name,
receiver: receiver,
cr: cr,
cw: cw,
closer: closer,
awaiting: make(map[int]chan asyncResult),
inbox: make(chan message),
outbox: make(chan asyncMessage),
@ -282,10 +295,6 @@ func (c *rawConnection) ID() DeviceID {
return c.id
}
func (c *rawConnection) Name() string {
return c.name
}
// Index writes the list of file information to the connected peer device
func (c *rawConnection) Index(ctx context.Context, folder string, idx []FileInfo) error {
select {
@ -931,6 +940,9 @@ func (c *rawConnection) Close(err error) {
func (c *rawConnection) internalClose(err error) {
c.closeOnce.Do(func() {
l.Debugln("close due to", err)
if cerr := c.closer.Close(); cerr != nil {
l.Debugln(c.id, "failed to close underlying conn:", cerr)
}
close(c.closed)
c.awaitingMut.Lock()

View File

@ -31,10 +31,10 @@ func TestPing(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
@ -57,10 +57,10 @@ func TestClose(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, m1, "name", CompressionAlways)
c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, &testutils.FakeConnectionInfo{"name"}, CompressionAlways)
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
@ -102,7 +102,7 @@ func TestCloseOnBlockingSend(t *testing.T) {
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start()
defer closeAndWait(c, rw)
@ -153,10 +153,10 @@ func TestCloseRace(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, &testutils.FakeConnectionInfo{"c0"}, CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressionNever)
c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, &testutils.FakeConnectionInfo{"c1"}, CompressionNever)
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
@ -193,7 +193,7 @@ func TestClusterConfigFirst(t *testing.T) {
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start()
defer closeAndWait(c, rw)
@ -245,7 +245,7 @@ func TestCloseTimeout(t *testing.T) {
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start()
defer closeAndWait(c, rw)
@ -865,7 +865,7 @@ func TestClusterConfigAfterClose(t *testing.T) {
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start()
defer closeAndWait(c, rw)
@ -889,7 +889,7 @@ func TestDispatcherToCloseDeadlock(t *testing.T) {
// the model callbacks (ClusterConfig).
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
m.ccFn = func(devID DeviceID, cc ClusterConfig) {
c.Close(errManual)
}

View File

@ -8,6 +8,7 @@ package testutils
import (
"errors"
"net"
"sync"
)
@ -52,3 +53,49 @@ func (rw *NoopRW) Read(p []byte) (n int, err error) {
func (rw *NoopRW) Write(p []byte) (n int, err error) {
return len(p), nil
}
type NoopCloser struct{}
func (NoopCloser) Close() error {
return nil
}
// FakeConnectionInfo implements the methods of protocol.Connection that are
// not implemented by protocol.Connection
type FakeConnectionInfo struct {
Name string
}
func (f *FakeConnectionInfo) RemoteAddr() net.Addr {
return &FakeAddr{}
}
func (f *FakeConnectionInfo) Type() string {
return "fake"
}
func (f *FakeConnectionInfo) Crypto() string {
return "fake"
}
func (f *FakeConnectionInfo) Transport() string {
return "fake"
}
func (f *FakeConnectionInfo) Priority() int {
return 9000
}
func (f *FakeConnectionInfo) String() string {
return ""
}
type FakeAddr struct{}
func (FakeAddr) Network() string {
return "network"
}
func (FakeAddr) String() string {
return "address"
}