diff --git a/protocol.go b/protocol.go index b9859203a..605de4781 100644 --- a/protocol.go +++ b/protocol.go @@ -31,8 +31,7 @@ const ( const ( stateInitial = iota - stateCCRcvd - stateIdxRcvd + stateReady ) // FileInfo flags @@ -103,7 +102,6 @@ type rawConnection struct { id DeviceID name string receiver Model - state int cr *countingReader cw *countingWriter @@ -155,7 +153,6 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv id: deviceID, name: name, receiver: nativeModel{receiver}, - state: stateInitial, cr: cr, cw: cw, outbox: make(chan hdrMsg), @@ -285,6 +282,7 @@ func (c *rawConnection) readerLoop() (err error) { c.close(err) }() + state := stateInitial for { select { case <-c.closed: @@ -298,47 +296,54 @@ func (c *rawConnection) readerLoop() (err error) { } switch msg := msg.(type) { + case ClusterConfigMessage: + if state != stateInitial { + return fmt.Errorf("protocol error: cluster config message in state %d", state) + } + go c.receiver.ClusterConfig(c.id, msg) + state = stateReady + case IndexMessage: switch hdr.msgType { case messageTypeIndex: - if c.state < stateCCRcvd { - return fmt.Errorf("protocol error: index message in state %d", c.state) + if state != stateReady { + return fmt.Errorf("protocol error: index message in state %d", state) } c.handleIndex(msg) - c.state = stateIdxRcvd + state = stateReady case messageTypeIndexUpdate: - if c.state < stateIdxRcvd { - return fmt.Errorf("protocol error: index update message in state %d", c.state) + if state != stateReady { + return fmt.Errorf("protocol error: index update message in state %d", state) } c.handleIndexUpdate(msg) + state = stateReady } case RequestMessage: - if c.state < stateIdxRcvd { - return fmt.Errorf("protocol error: request message in state %d", c.state) + if state != stateReady { + return fmt.Errorf("protocol error: request message in state %d", state) } // Requests are handled asynchronously go c.handleRequest(hdr.msgID, msg) case ResponseMessage: - if c.state < stateIdxRcvd { - return fmt.Errorf("protocol error: response message in state %d", c.state) + if state != stateReady { + return fmt.Errorf("protocol error: response message in state %d", state) } c.handleResponse(hdr.msgID, msg) case pingMessage: + if state != stateReady { + return fmt.Errorf("protocol error: ping message in state %d", state) + } c.send(hdr.msgID, messageTypePong, pongMessage{}) case pongMessage: - c.handlePong(hdr.msgID) - - case ClusterConfigMessage: - if c.state != stateInitial { - return fmt.Errorf("protocol error: cluster config message in state %d", c.state) + if state != stateReady { + return fmt.Errorf("protocol error: pong message in state %d", state) } - go c.receiver.ClusterConfig(c.id, msg) - c.state = stateCCRcvd + c.handlePong(hdr.msgID) case CloseMessage: return errors.New(msg.Reason) diff --git a/protocol_test.go b/protocol_test.go index 3ff1042cb..bb4fe7d95 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -67,8 +67,10 @@ func TestPing(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() - c0 := NewConnection(c0ID, ar, bw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) - c1 := NewConnection(c1ID, br, aw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) + c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) + c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) + c0.ClusterConfig(ClusterConfigMessage{}) + c1.ClusterConfig(ClusterConfigMessage{}) if ok := c0.ping(); !ok { t.Error("c0 ping failed") @@ -81,8 +83,8 @@ func TestPing(t *testing.T) { func TestPingErr(t *testing.T) { e := errors.New("something broke") - for i := 0; i < 16; i++ { - for j := 0; j < 16; j++ { + for i := 0; i < 32; i++ { + for j := 0; j < 32; j++ { m0 := newTestModel() m1 := newTestModel() @@ -92,12 +94,16 @@ func TestPingErr(t *testing.T) { ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} c0 := NewConnection(c0ID, ar, ebw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) - NewConnection(c1ID, br, eaw, m1, "name", CompressAlways) + c1 := NewConnection(c1ID, br, eaw, m1, "name", CompressAlways) + c0.ClusterConfig(ClusterConfigMessage{}) + c1.ClusterConfig(ClusterConfigMessage{}) res := c0.ping() if (i < 8 || j < 8) && res { + // This should have resulted in failure, as there is no way an empty ClusterConfig plus a Ping message fits in eight bytes. t.Errorf("Unexpected ping success; i=%d, j=%d", i, j) - } else if (i >= 12 && j >= 12) && !res { + } else if (i >= 28 && j >= 28) && !res { + // This should have worked though, as 28 bytes is plenty for both. t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j) } } @@ -168,7 +174,9 @@ func TestVersionErr(t *testing.T) { br, bw := io.Pipe() c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) - NewConnection(c1ID, br, aw, m1, "name", CompressAlways) + c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways) + c0.ClusterConfig(ClusterConfigMessage{}) + c1.ClusterConfig(ClusterConfigMessage{}) w := xdr.NewWriter(c0.cw) w.WriteUint32(encodeHeader(header{ @@ -191,7 +199,9 @@ func TestTypeErr(t *testing.T) { br, bw := io.Pipe() c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) - NewConnection(c1ID, br, aw, m1, "name", CompressAlways) + c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways) + c0.ClusterConfig(ClusterConfigMessage{}) + c1.ClusterConfig(ClusterConfigMessage{}) w := xdr.NewWriter(c0.cw) w.WriteUint32(encodeHeader(header{ @@ -214,7 +224,9 @@ func TestClose(t *testing.T) { br, bw := io.Pipe() c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) - NewConnection(c1ID, br, aw, m1, "name", CompressAlways) + c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways) + c0.ClusterConfig(ClusterConfigMessage{}) + c1.ClusterConfig(ClusterConfigMessage{}) c0.close(nil)