From 482795bab0b9a15cfc41316d5cb2c07b35e4c852 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Sun, 4 May 2014 17:40:40 +0200 Subject: [PATCH] Streamline error handling and locking, with fix for close() race --- protocol/protocol.go | 238 +++++++++++++++++++++---------------------- 1 file changed, 118 insertions(+), 120 deletions(-) diff --git a/protocol/protocol.go b/protocol/protocol.go index 889199591..f8fffa2fe 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -64,24 +64,24 @@ type Connection interface { } type rawConnection struct { - sync.RWMutex + id string + receiver Model + + reader io.ReadCloser + cr *countingReader + xr *xdr.Reader + writer io.WriteCloser + + cw *countingWriter + wb *bufio.Writer + xw *xdr.Writer + wmut sync.Mutex + closed bool - id string - receiver Model - reader io.ReadCloser - cr *countingReader - xr *xdr.Reader - writer io.WriteCloser - cw *countingWriter - wb *bufio.Writer - xw *xdr.Writer - closed chan struct{} awaiting map[int]chan asyncResult nextID int indexSent map[string]map[string][2]int64 - - hasSentIndex bool - hasRecvdIndex bool + imut sync.Mutex } type asyncResult struct { @@ -115,7 +115,6 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M cw: cw, wb: wb, xw: xdr.NewWriter(wb), - closed: make(chan struct{}), awaiting: make(map[int]chan asyncResult), indexSent: make(map[string]map[string][2]int64), } @@ -132,11 +131,11 @@ func (c *rawConnection) ID() string { // Index writes the list of file information to the connected peer node func (c *rawConnection) Index(repo string, idx []FileInfo) { - c.Lock() if c.isClosed() { - c.Unlock() return } + + c.imut.Lock() var msgType int if c.indexSent[repo] == nil { // This is the first time we send an index. @@ -159,14 +158,15 @@ func (c *rawConnection) Index(repo string, idx []FileInfo) { idx = diff } - header{0, c.nextID, msgType}.encodeXDR(c.xw) - _, err := IndexMessage{repo, idx}.encodeXDR(c.xw) - if err == nil { - err = c.flush() - } + id := c.nextID c.nextID = (c.nextID + 1) & 0xfff - c.hasSentIndex = true - c.Unlock() + c.imut.Unlock() + + c.wmut.Lock() + header{0, id, msgType}.encodeXDR(c.xw) + IndexMessage{repo, idx}.encodeXDR(c.xw) + err := c.flush() + c.wmut.Unlock() if err != nil { c.close(err) @@ -176,28 +176,30 @@ func (c *rawConnection) Index(repo string, idx []FileInfo) { // Request returns the bytes for the specified block after fetching them from the connected peer. func (c *rawConnection) Request(repo string, name string, offset int64, size int) ([]byte, error) { - c.Lock() if c.isClosed() { - c.Unlock() return nil, ErrClosed } + + c.imut.Lock() + id := c.nextID + c.nextID = (c.nextID + 1) & 0xfff rc := make(chan asyncResult) - if _, ok := c.awaiting[c.nextID]; ok { + if _, ok := c.awaiting[id]; ok { panic("id taken") } - c.awaiting[c.nextID] = rc - header{0, c.nextID, messageTypeRequest}.encodeXDR(c.xw) - _, err := RequestMessage{repo, name, uint64(offset), uint32(size)}.encodeXDR(c.xw) - if err == nil { - err = c.flush() - } + c.awaiting[id] = rc + c.imut.Unlock() + + c.wmut.Lock() + header{0, id, messageTypeRequest}.encodeXDR(c.xw) + RequestMessage{repo, name, uint64(offset), uint32(size)}.encodeXDR(c.xw) + err := c.flush() + c.wmut.Unlock() + if err != nil { - c.Unlock() c.close(err) return nil, err } - c.nextID = (c.nextID + 1) & 0xfff - c.Unlock() res, ok := <-rc if !ok { @@ -208,46 +210,47 @@ func (c *rawConnection) Request(repo string, name string, offset int64, size int // ClusterConfig send the cluster configuration message to the peer and returns any error func (c *rawConnection) ClusterConfig(config ClusterConfigMessage) { - c.Lock() - defer c.Unlock() - if c.isClosed() { return } - header{0, c.nextID, messageTypeClusterConfig}.encodeXDR(c.xw) + c.imut.Lock() + id := c.nextID c.nextID = (c.nextID + 1) & 0xfff + c.imut.Unlock() + + c.wmut.Lock() + header{0, id, messageTypeClusterConfig}.encodeXDR(c.xw) + config.encodeXDR(c.xw) + err := c.flush() + c.wmut.Unlock() - _, err := config.encodeXDR(c.xw) - if err == nil { - err = c.flush() - } if err != nil { c.close(err) } } func (c *rawConnection) ping() bool { - c.Lock() if c.isClosed() { - c.Unlock() return false } + + c.imut.Lock() + id := c.nextID + c.nextID = (c.nextID + 1) & 0xfff rc := make(chan asyncResult, 1) - c.awaiting[c.nextID] = rc - header{0, c.nextID, messageTypePing}.encodeXDR(c.xw) + c.awaiting[id] = rc + c.imut.Unlock() + + c.wmut.Lock() + header{0, id, messageTypePing}.encodeXDR(c.xw) err := c.flush() + c.wmut.Unlock() + if err != nil { - c.Unlock() c.close(err) return false - } else if c.xw.Error() != nil { - c.Unlock() - c.close(c.xw.Error()) - return false } - c.nextID = (c.nextID + 1) & 0xfff - c.Unlock() res, ok := <-rc return ok && res.err == nil @@ -258,40 +261,47 @@ type flusher interface { } func (c *rawConnection) flush() error { - c.wb.Flush() + if err := c.xw.Error(); err != nil { + return err + } + + if err := c.wb.Flush(); err != nil { + return err + } + if f, ok := c.writer.(flusher); ok { return f.Flush() } + return nil } func (c *rawConnection) close(err error) { - c.Lock() - select { - case <-c.closed: - c.Unlock() + c.imut.Lock() + c.wmut.Lock() + defer c.imut.Unlock() + defer c.wmut.Unlock() + + if c.closed { return - default: } - close(c.closed) + + c.closed = true + for _, ch := range c.awaiting { close(ch) } c.awaiting = nil c.writer.Close() c.reader.Close() - c.Unlock() c.receiver.Close(c.id, err) } func (c *rawConnection) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } + c.wmut.Lock() + defer c.wmut.Unlock() + return c.closed } func (c *rawConnection) readerLoop() { @@ -299,8 +309,8 @@ loop: for !c.isClosed() { var hdr header hdr.decodeXDR(c.xr) - if c.xr.Error() != nil { - c.close(c.xr.Error()) + if err := c.xr.Error(); err != nil { + c.close(err) break loop } if hdr.version != 0 { @@ -312,8 +322,8 @@ loop: case messageTypeIndex: var im IndexMessage im.decodeXDR(c.xr) - if c.xr.Error() != nil { - c.close(c.xr.Error()) + if err := c.xr.Error(); err != nil { + c.close(err) break loop } else { @@ -326,15 +336,12 @@ loop: go c.receiver.Index(c.id, im.Repository, im.Files) } - c.Lock() - c.hasRecvdIndex = true - c.Unlock() case messageTypeIndexUpdate: var im IndexMessage im.decodeXDR(c.xr) - if c.xr.Error() != nil { - c.close(c.xr.Error()) + if err := c.xr.Error(); err != nil { + c.close(err) break loop } else { go c.receiver.IndexUpdate(c.id, im.Repository, im.Files) @@ -343,8 +350,8 @@ loop: case messageTypeRequest: var req RequestMessage req.decodeXDR(c.xr) - if c.xr.Error() != nil { - c.close(c.xr.Error()) + if err := c.xr.Error(); err != nil { + c.close(err) break loop } go c.processRequest(hdr.msgID, req) @@ -352,16 +359,16 @@ loop: case messageTypeResponse: data := c.xr.ReadBytesMax(256 * 1024) // Sufficiently larger than max expected block size - if c.xr.Error() != nil { - c.close(c.xr.Error()) + if err := c.xr.Error(); err != nil { + c.close(err) break loop } go func(hdr header, err error) { - c.Lock() + c.imut.Lock() rc, ok := c.awaiting[hdr.msgID] delete(c.awaiting, hdr.msgID) - c.Unlock() + c.imut.Unlock() if ok { rc <- asyncResult{data, err} @@ -370,37 +377,34 @@ loop: }(hdr, c.xr.Error()) case messageTypePing: - c.Lock() + c.wmut.Lock() header{0, hdr.msgID, messageTypePong}.encodeXDR(c.xw) err := c.flush() - c.Unlock() + c.wmut.Unlock() if err != nil { c.close(err) break loop - } else if c.xw.Error() != nil { - c.close(c.xw.Error()) - break loop } case messageTypePong: - c.RLock() + c.imut.Lock() rc, ok := c.awaiting[hdr.msgID] - c.RUnlock() if ok { - rc <- asyncResult{} - close(rc) + go func() { + rc <- asyncResult{} + close(rc) + }() - c.Lock() delete(c.awaiting, hdr.msgID) - c.Unlock() } + c.imut.Unlock() case messageTypeClusterConfig: var cm ClusterConfigMessage cm.decodeXDR(c.xr) - if c.xr.Error() != nil { - c.close(c.xr.Error()) + if err := c.xr.Error(); err != nil { + c.close(err) break loop } else { go c.receiver.ClusterConfig(c.id, cm) @@ -416,15 +420,14 @@ loop: func (c *rawConnection) processRequest(msgID int, req RequestMessage) { data, _ := c.receiver.Request(c.id, req.Repository, req.Name, int64(req.Offset), int(req.Size)) - c.Lock() + c.wmut.Lock() header{0, msgID, messageTypeResponse}.encodeXDR(c.xw) - _, err := c.xw.WriteBytes(data) - if err == nil { - err = c.flush() - } - c.Unlock() + c.xw.WriteBytes(data) + err := c.flush() + c.wmut.Unlock() buffers.Put(data) + if err != nil { c.close(err) } @@ -434,27 +437,22 @@ func (c *rawConnection) pingerLoop() { var rc = make(chan bool, 1) ticker := time.Tick(pingIdleTime / 2) for { + if c.isClosed() { + return + } select { case <-ticker: - c.RLock() - ready := c.hasRecvdIndex && c.hasSentIndex - c.RUnlock() - - if ready { - go func() { - rc <- c.ping() - }() - select { - case ok := <-rc: - if !ok { - c.close(fmt.Errorf("ping failure")) - } - case <-time.After(pingTimeout): - c.close(fmt.Errorf("ping timeout")) + go func() { + rc <- c.ping() + }() + select { + case ok := <-rc: + if !ok { + c.close(fmt.Errorf("ping failure")) } + case <-time.After(pingTimeout): + c.close(fmt.Errorf("ping timeout")) } - case <-c.closed: - return } } }