diff --git a/main.go b/main.go index e35884fce..2b2d78463 100644 --- a/main.go +++ b/main.go @@ -496,7 +496,7 @@ func saveIndex(m *Model) { gzw := gzip.NewWriter(idxf) - protocol.WriteIndex(gzw, m.ProtocolIndex()) + protocol.WriteIndex(gzw, "local", m.ProtocolIndex()) gzw.Close() idxf.Close() os.Rename(fullName+".tmp", fullName) @@ -516,8 +516,8 @@ func loadIndex(m *Model) { } defer gzr.Close() - idx, err := protocol.ReadIndex(gzr) - if err != nil { + repo, idx, err := protocol.ReadIndex(gzr) + if repo != "local" || err != nil { return } m.SeedLocal(idx) diff --git a/model.go b/model.go index 80af8252b..003acf57e 100644 --- a/model.go +++ b/model.go @@ -55,8 +55,8 @@ type Model struct { type Connection interface { ID() string - Index([]protocol.FileInfo) - Request(name string, offset int64, size uint32, hash []byte) ([]byte, error) + Index(string, []protocol.FileInfo) + Request(repo, name string, offset int64, size uint32, hash []byte) ([]byte, error) Statistics() protocol.Statistics Option(key string) string } @@ -360,6 +360,8 @@ func (m *Model) Close(node string, err error) { } if err == protocol.ErrClusterHash { warnf("Connection to %s closed due to mismatched cluster hash. Ensure that the configured cluster members are identical on both nodes.", node) + } else if err != io.EOF { + warnf("Connection to %s closed: %v", node, err) } m.fq.RemoveAvailable(node) @@ -385,7 +387,7 @@ func (m *Model) Close(node string, err error) { // Request returns the specified data segment by reading it from local disk. // Implements the protocol.Model interface. -func (m *Model) Request(nodeID, name string, offset int64, size uint32, hash []byte) ([]byte, error) { +func (m *Model) Request(nodeID, repo, name string, offset int64, size uint32, hash []byte) ([]byte, error) { // Verify that the requested file exists in the local and global model. m.lmut.RLock() lf, localOk := m.local[name] @@ -507,7 +509,7 @@ func (m *Model) AddConnection(rawConn io.Closer, protoConn Connection) { go func() { idx := m.ProtocolIndex() - protoConn.Index(idx) + protoConn.Index("default", idx) }() m.initmut.Lock() @@ -539,7 +541,7 @@ func (m *Model) AddConnection(rawConn io.Closer, protoConn Connection) { if m.trace["pull"] { debugln("PULL: Request", nodeID, i, qb.name, qb.block.Offset) } - data, _ := protoConn.Request(qb.name, qb.block.Offset, qb.block.Size, qb.block.Hash) + data, _ := protoConn.Request("default", qb.name, qb.block.Offset, qb.block.Size, qb.block.Hash) m.fq.Done(qb.name, qb.block.Offset, data) } else { time.Sleep(1 * time.Second) @@ -585,7 +587,7 @@ func (m *Model) requestGlobal(nodeID, name string, offset int64, size uint32, ha debugf("NET REQ(out): %s: %q o=%d s=%d h=%x", nodeID, name, offset, size, hash) } - return nc.Request(name, offset, size, hash) + return nc.Request("default", name, offset, size, hash) } func (m *Model) broadcastIndexLoop() { @@ -613,7 +615,7 @@ func (m *Model) broadcastIndexLoop() { debugf("NET IDX(out/loop): %s: %d files", node.ID(), len(idx)) } go func() { - node.Index(idx) + node.Index("default", idx) indexWg.Done() }() } diff --git a/model_test.go b/model_test.go index ef30b59d8..43909f59b 100644 --- a/model_test.go +++ b/model_test.go @@ -345,7 +345,7 @@ func TestRequest(t *testing.T) { fs, _ := m.Walk(false) m.ReplaceLocal(fs) - bs, err := m.Request("some node", "foo", 0, 6, nil) + bs, err := m.Request("some node", "default", "foo", 0, 6, nil) if err != nil { t.Fatal(err) } @@ -353,7 +353,7 @@ func TestRequest(t *testing.T) { t.Errorf("Incorrect data from request: %q", string(bs)) } - bs, err = m.Request("some node", "../walk.go", 0, 6, nil) + bs, err = m.Request("some node", "default", "../walk.go", 0, 6, nil) if err == nil { t.Error("Unexpected nil error on insecure file read") } @@ -487,9 +487,9 @@ func (f FakeConnection) Option(string) string { return "" } -func (FakeConnection) Index([]protocol.FileInfo) {} +func (FakeConnection) Index(string, []protocol.FileInfo) {} -func (f FakeConnection) Request(name string, offset int64, size uint32, hash []byte) ([]byte, error) { +func (f FakeConnection) Request(repo, name string, offset int64, size uint32, hash []byte) ([]byte, error) { return f.requestData, nil } diff --git a/protocol/PROTOCOL.md b/protocol/PROTOCOL.md index aed7823df..885f753f4 100644 --- a/protocol/PROTOCOL.md +++ b/protocol/PROTOCOL.md @@ -84,6 +84,7 @@ an empty Index message must be sent. There is no response to the Index message. struct IndexMessage { + string Repository<>; FileInfo Files<>; } @@ -100,6 +101,10 @@ message. opaque Hash<> } +The Repository field identifies the repository that the index message +pertains to. For single repository implementations an empty repository +ID is acceptable. + The file name is the part relative to the repository root. The modification time is expressed as the number of seconds since the Unix Epoch. The version field is a counter that increments each time the file @@ -143,6 +148,7 @@ before transmitting data. Each Request message must be met with a Response message. struct RequestMessage { + string Repository<>; string Name<>; unsigned hyper Offset; unsigned int Length; @@ -248,4 +254,3 @@ their repository contents and transmits an updated Index message (10). Both peers enter idle state after 10. At some later time 11, peer A determines that it has not seen data from B for some time and sends a Ping request. A response is sent at 12. - diff --git a/protocol/common_test.go b/protocol/common_test.go index a76e02633..9a1b0e457 100644 --- a/protocol/common_test.go +++ b/protocol/common_test.go @@ -4,6 +4,7 @@ import "io" type TestModel struct { data []byte + repo string name string offset int64 size uint32 @@ -17,7 +18,8 @@ func (t *TestModel) Index(nodeID string, files []FileInfo) { func (t *TestModel) IndexUpdate(nodeID string, files []FileInfo) { } -func (t *TestModel) Request(nodeID, name string, offset int64, size uint32, hash []byte) ([]byte, error) { +func (t *TestModel) Request(nodeID, repo, name string, offset int64, size uint32, hash []byte) ([]byte, error) { + t.repo = repo t.name = name t.offset = offset t.size = size diff --git a/protocol/marshal.go b/protocol/marshal.go index c1bce9270..c1012a1e0 100644 --- a/protocol/marshal.go +++ b/protocol/marshal.go @@ -30,7 +30,7 @@ type marshalWriter struct { // memory when reading a corrupt message. const maxBytesFieldLength = 10 * 1 << 20 -var ErrFieldLengthExceeded = errors.New("Raw bytes field size exceeds limit") +var ErrFieldLengthExceeded = errors.New("Protocol error: raw bytes field size exceeds limit") func (w *marshalWriter) writeString(s string) { w.writeBytes([]byte(s)) diff --git a/protocol/messages.go b/protocol/messages.go index 9a294609f..cfeb8c63f 100644 --- a/protocol/messages.go +++ b/protocol/messages.go @@ -1,8 +1,22 @@ package protocol -import "io" +import ( + "errors" + "io" +) + +const ( + maxNumFiles = 100000 // More than 100000 files is a protocol error + maxNumBlocks = 100000 // 100000 * 128KB = 12.5 GB max acceptable file size +) + +var ( + ErrMaxFilesExceeded = errors.New("Protocol error: number of files per index exceeds limit") + ErrMaxBlocksExceeded = errors.New("Protocol error: number of blocks per file exceeds limit") +) type request struct { + repo string name string offset int64 size uint32 @@ -33,7 +47,8 @@ func (w *marshalWriter) writeHeader(h header) { w.writeUint32(encodeHeader(h)) } -func (w *marshalWriter) writeIndex(idx []FileInfo) { +func (w *marshalWriter) writeIndex(repo string, idx []FileInfo) { + w.writeString(repo) w.writeUint32(uint32(len(idx))) for _, f := range idx { w.writeString(f.Name) @@ -48,13 +63,14 @@ func (w *marshalWriter) writeIndex(idx []FileInfo) { } } -func WriteIndex(w io.Writer, idx []FileInfo) (int, error) { +func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) { mw := marshalWriter{w: w} - mw.writeIndex(idx) + mw.writeIndex(repo, idx) return int(mw.getTot()), mw.err } func (w *marshalWriter) writeRequest(r request) { + w.writeString(r.repo) w.writeString(r.name) w.writeUint64(uint64(r.offset)) w.writeUint32(r.size) @@ -77,9 +93,14 @@ func (r *marshalReader) readHeader() header { return decodeHeader(r.readUint32()) } -func (r *marshalReader) readIndex() []FileInfo { +func (r *marshalReader) readIndex() (string, []FileInfo) { var files []FileInfo + repo := r.readString() nfiles := r.readUint32() + if nfiles > maxNumFiles { + r.err = ErrMaxFilesExceeded + return "", nil + } if nfiles > 0 { files = make([]FileInfo, nfiles) for i := range files { @@ -88,6 +109,10 @@ func (r *marshalReader) readIndex() []FileInfo { files[i].Modified = int64(r.readUint64()) files[i].Version = r.readUint32() nblocks := r.readUint32() + if nblocks > maxNumBlocks { + r.err = ErrMaxBlocksExceeded + return "", nil + } blocks := make([]BlockInfo, nblocks) for j := range blocks { blocks[j].Size = r.readUint32() @@ -96,17 +121,18 @@ func (r *marshalReader) readIndex() []FileInfo { files[i].Blocks = blocks } } - return files + return repo, files } -func ReadIndex(r io.Reader) ([]FileInfo, error) { +func ReadIndex(r io.Reader) (string, []FileInfo, error) { mr := marshalReader{r: r} - idx := mr.readIndex() - return idx, mr.err + repo, idx := mr.readIndex() + return repo, idx, mr.err } func (r *marshalReader) readRequest() request { var req request + req.repo = r.readString() req.name = r.readString() req.offset = int64(r.readUint64()) req.size = r.readUint32() diff --git a/protocol/messages_test.go b/protocol/messages_test.go index 3be31b7ed..de3c90755 100644 --- a/protocol/messages_test.go +++ b/protocol/messages_test.go @@ -35,10 +35,14 @@ func TestIndex(t *testing.T) { var buf = new(bytes.Buffer) var wr = marshalWriter{w: buf} - wr.writeIndex(idx) + wr.writeIndex("default", idx) var rd = marshalReader{r: buf} - var idx2 = rd.readIndex() + var repo, idx2 = rd.readIndex() + + if repo != "default" { + t.Error("Incorrect repo", repo) + } if !reflect.DeepEqual(idx, idx2) { t.Errorf("Index marshal error:\n%#v\n%#v\n", idx, idx2) @@ -46,9 +50,9 @@ func TestIndex(t *testing.T) { } func TestRequest(t *testing.T) { - f := func(name string, offset int64, size uint32, hash []byte) bool { + f := func(repo, name string, offset int64, size uint32, hash []byte) bool { var buf = new(bytes.Buffer) - var req = request{name, offset, size, hash} + var req = request{repo, name, offset, size, hash} var wr = marshalWriter{w: buf} wr.writeRequest(req) var rd = marshalReader{r: buf} @@ -105,12 +109,12 @@ func BenchmarkWriteIndex(b *testing.B) { var wr = marshalWriter{w: ioutil.Discard} for i := 0; i < b.N; i++ { - wr.writeIndex(idx) + wr.writeIndex("default", idx) } } func BenchmarkWriteRequest(b *testing.B) { - var req = request{"blah blah", 1231323, 13123123, []byte("hash hash hash")} + var req = request{"default", "blah blah", 1231323, 13123123, []byte("hash hash hash")} var wr = marshalWriter{w: ioutil.Discard} for i := 0; i < b.N; i++ { diff --git a/protocol/protocol.go b/protocol/protocol.go index 77d97f5dc..62ce8b9c1 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -50,7 +50,7 @@ type Model interface { // An index update was received from the peer node IndexUpdate(nodeID string, files []FileInfo) // A request was made by the peer node - Request(nodeID, name string, offset int64, size uint32, hash []byte) ([]byte, error) + Request(nodeID, repo string, name string, offset int64, size uint32, hash []byte) ([]byte, error) // The peer node closed the connection Close(nodeID string, err error) } @@ -67,7 +67,7 @@ type Connection struct { closed bool awaiting map[int]chan asyncResult nextId int - indexSent map[string][2]int64 + indexSent map[string]map[string][2]int64 peerOptions map[string]string myOptions map[string]string optionsLock sync.Mutex @@ -98,13 +98,14 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M } c := Connection{ - id: nodeID, - receiver: receiver, - reader: flrd, - mreader: &marshalReader{r: flrd}, - writer: flwr, - mwriter: &marshalWriter{w: flwr}, - awaiting: make(map[int]chan asyncResult), + id: nodeID, + receiver: receiver, + reader: flrd, + mreader: &marshalReader{r: flrd}, + writer: flwr, + mwriter: &marshalWriter{w: flwr}, + awaiting: make(map[int]chan asyncResult), + indexSent: make(map[string]map[string][2]int64), } go c.readerLoop() @@ -133,32 +134,32 @@ func (c *Connection) ID() string { } // Index writes the list of file information to the connected peer node -func (c *Connection) Index(idx []FileInfo) { +func (c *Connection) Index(repo string, idx []FileInfo) { c.Lock() var msgType int - if c.indexSent == nil { + if c.indexSent[repo] == nil { // This is the first time we send an index. msgType = messageTypeIndex - c.indexSent = make(map[string][2]int64) + c.indexSent[repo] = make(map[string][2]int64) for _, f := range idx { - c.indexSent[f.Name] = [2]int64{f.Modified, int64(f.Version)} + c.indexSent[repo][f.Name] = [2]int64{f.Modified, int64(f.Version)} } } else { // We have sent one full index. Only send updates now. msgType = messageTypeIndexUpdate var diff []FileInfo for _, f := range idx { - if vs, ok := c.indexSent[f.Name]; !ok || f.Modified != vs[0] || int64(f.Version) != vs[1] { + if vs, ok := c.indexSent[repo][f.Name]; !ok || f.Modified != vs[0] || int64(f.Version) != vs[1] { diff = append(diff, f) - c.indexSent[f.Name] = [2]int64{f.Modified, int64(f.Version)} + c.indexSent[repo][f.Name] = [2]int64{f.Modified, int64(f.Version)} } } idx = diff } c.mwriter.writeHeader(header{0, c.nextId, msgType}) - c.mwriter.writeIndex(idx) + c.mwriter.writeIndex(repo, idx) err := c.flush() c.nextId = (c.nextId + 1) & 0xfff c.hasSentIndex = true @@ -174,7 +175,7 @@ func (c *Connection) Index(idx []FileInfo) { } // Request returns the bytes for the specified block after fetching them from the connected peer. -func (c *Connection) Request(name string, offset int64, size uint32, hash []byte) ([]byte, error) { +func (c *Connection) Request(repo string, name string, offset int64, size uint32, hash []byte) ([]byte, error) { c.Lock() if c.closed { c.Unlock() @@ -183,7 +184,7 @@ func (c *Connection) Request(name string, offset int64, size uint32, hash []byte rc := make(chan asyncResult) c.awaiting[c.nextId] = rc c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest}) - c.mwriter.writeRequest(request{name, offset, size, hash}) + c.mwriter.writeRequest(request{repo, name, offset, size, hash}) if c.mwriter.err != nil { c.Unlock() c.close(c.mwriter.err) @@ -279,7 +280,8 @@ loop: switch hdr.msgType { case messageTypeIndex: - files := c.mreader.readIndex() + repo, files := c.mreader.readIndex() + _ = repo if c.mreader.err != nil { c.close(c.mreader.err) break loop @@ -291,7 +293,8 @@ loop: c.Unlock() case messageTypeIndexUpdate: - files := c.mreader.readIndex() + repo, files := c.mreader.readIndex() + _ = repo if c.mreader.err != nil { c.close(c.mreader.err) break loop @@ -370,7 +373,7 @@ loop: } func (c *Connection) processRequest(msgID int, req request) { - data, _ := c.receiver.Request(c.id, req.name, req.offset, req.size, req.hash) + data, _ := c.receiver.Request(c.id, req.repo, req.name, req.offset, req.size, req.hash) c.Lock() c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse})) diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index b9434588a..08e41c49e 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -84,8 +84,8 @@ func TestRequestResponseErr(t *testing.T) { e := errors.New("Something broke") var pass bool - for i := 0; i < 36; i++ { - for j := 0; j < 26; j++ { + for i := 0; i < 48; i++ { + for j := 0; j < 38; j++ { m0 := &TestModel{data: []byte("response data")} m1 := &TestModel{} @@ -97,7 +97,7 @@ func TestRequestResponseErr(t *testing.T) { NewConnection("c0", ar, ebw, m0, nil) c1 := NewConnection("c1", br, eaw, m1, nil) - d, err := c1.Request("tn", 1234, 3456, []byte("hashbytes")) + d, err := c1.Request("default", "tn", 1234, 3456, []byte("hashbytes")) if err == e || err == ErrClosed { t.Logf("Error at %d+%d bytes", i, j) if !m1.closed { @@ -115,6 +115,9 @@ func TestRequestResponseErr(t *testing.T) { if string(d) != "response data" { t.Errorf("Incorrect response data %q", string(d)) } + if m0.repo != "default" { + t.Error("Incorrect repo %q", m0.repo) + } if m0.name != "tn" { t.Error("Incorrect name %q", m0.name) } @@ -204,10 +207,10 @@ func TestClose(t *testing.T) { t.Error("Ping should not return true") } - c0.Index(nil) - c0.Index(nil) + c0.Index("default", nil) + c0.Index("default", nil) - _, err := c0.Request("foo", 0, 0, nil) + _, err := c0.Request("default", "foo", 0, 0, nil) if err == nil { t.Error("Request should return an error") }