From 108b4e2e104610bdf416f2f156f35ee769276caf Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 23 Feb 2015 09:30:47 +0100 Subject: [PATCH] Add more fine grained compression control --- compression.go | 53 +++++++++++++++++++++++++++++++++++++++++++++ compression_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++ protocol.go | 38 +++++++++++++++++--------------- protocol_test.go | 20 ++++++++--------- 4 files changed, 135 insertions(+), 27 deletions(-) create mode 100644 compression.go create mode 100644 compression_test.go diff --git a/compression.go b/compression.go new file mode 100644 index 000000000..9e17213b6 --- /dev/null +++ b/compression.go @@ -0,0 +1,53 @@ +// Copyright (C) 2015 The Protocol Authors. + +package protocol + +import "fmt" + +type Compression int + +const ( + CompressMetadata Compression = iota // zero value is the default, default should be "metadata" + CompressNever + CompressAlways + + compressionThreshold = 128 // don't bother compressing messages smaller than this many bytes +) + +var compressionMarshal = map[Compression]string{ + CompressNever: "never", + CompressMetadata: "metadata", + CompressAlways: "always", +} + +var compressionUnmarshal = map[string]Compression{ + // Legacy + "false": CompressNever, + "true": CompressMetadata, + + // Current + "never": CompressNever, + "metadata": CompressMetadata, + "always": CompressAlways, +} + +func (c Compression) String() string { + s, ok := compressionMarshal[c] + if !ok { + return fmt.Sprintf("unknown:%d", c) + } + return s +} + +func (c Compression) GoString() string { + return fmt.Sprintf("%q", c.String()) +} + +func (c Compression) MarshalText() ([]byte, error) { + return []byte(compressionMarshal[c]), nil +} + +func (c *Compression) UnmarshalText(bs []byte) error { + *c = compressionUnmarshal[string(bs)] + return nil +} diff --git a/compression_test.go b/compression_test.go new file mode 100644 index 000000000..932297c32 --- /dev/null +++ b/compression_test.go @@ -0,0 +1,51 @@ +// Copyright (C) 2015 The Protocol Authors. + +package protocol + +import "testing" + +func TestCompressionMarshal(t *testing.T) { + uTestcases := []struct { + s string + c Compression + }{ + {"true", CompressMetadata}, + {"false", CompressNever}, + {"never", CompressNever}, + {"metadata", CompressMetadata}, + {"filedata", CompressFiledata}, + {"always", CompressAlways}, + {"whatever", CompressNever}, + } + + mTestcases := []struct { + s string + c Compression + }{ + {"never", CompressNever}, + {"metadata", CompressMetadata}, + {"filedata", CompressFiledata}, + {"always", CompressAlways}, + } + + var c Compression + for _, tc := range uTestcases { + err := c.UnmarshalText([]byte(tc.s)) + if err != nil { + t.Error(err) + } + if c != tc.c { + t.Errorf("%s unmarshalled to %d, not %d", tc.s, c, tc.c) + } + } + + for _, tc := range mTestcases { + bs, err := tc.c.MarshalText() + if err != nil { + t.Error(err) + } + if s := string(bs); s != tc.s { + t.Errorf("%d marshalled to %q, not %q", tc.c, s, tc.s) + } + } +} diff --git a/protocol.go b/protocol.go index f5c34dbff..e7b6fe275 100644 --- a/protocol.go +++ b/protocol.go @@ -106,7 +106,7 @@ type rawConnection struct { closed chan struct{} once sync.Once - compressionThreshold int // compress messages larger than this many bytes + compression Compression rdbuf0 []byte // used & reused by readMessage rdbuf1 []byte // used & reused by readMessage @@ -135,25 +135,21 @@ const ( pingIdleTime = 60 * time.Second ) -func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress bool) Connection { +func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection { cr := &countingReader{Reader: reader} cw := &countingWriter{Writer: writer} - compThres := 1<<31 - 1 // compression disabled - if compress { - compThres = 128 // compress messages that are 128 bytes long or larger - } c := rawConnection{ - id: deviceID, - name: name, - receiver: nativeModel{receiver}, - state: stateInitial, - cr: cr, - cw: cw, - outbox: make(chan hdrMsg), - nextID: make(chan int), - closed: make(chan struct{}), - compressionThreshold: compThres, + id: deviceID, + name: name, + receiver: nativeModel{receiver}, + state: stateInitial, + cr: cr, + cw: cw, + outbox: make(chan hdrMsg), + nextID: make(chan int), + closed: make(chan struct{}), + compression: compress, } go c.readerLoop() @@ -571,7 +567,15 @@ func (c *rawConnection) writerLoop() { return } - if len(uncBuf) >= c.compressionThreshold { + compress := false + switch c.compression { + case CompressAlways: + compress = true + case CompressMetadata: + compress = hm.hdr.msgType != messageTypeResponse + } + + if compress && len(uncBuf) >= compressionThreshold { // Use compression for large messages hm.hdr.compression = true diff --git a/protocol_test.go b/protocol_test.go index 1ccb4525f..c1048cdcf 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -67,8 +67,8 @@ func TestPing(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() - c0 := NewConnection(c0ID, ar, bw, nil, "name", true).(wireFormatConnection).next.(*rawConnection) - c1 := NewConnection(c1ID, br, aw, nil, "name", true).(wireFormatConnection).next.(*rawConnection) + c0 := NewConnection(c0ID, ar, bw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) + c1 := NewConnection(c1ID, br, aw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) if ok := c0.ping(); !ok { t.Error("c0 ping failed") @@ -91,8 +91,8 @@ func TestPingErr(t *testing.T) { eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e} ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} - c0 := NewConnection(c0ID, ar, ebw, m0, "name", true).(wireFormatConnection).next.(*rawConnection) - NewConnection(c1ID, br, eaw, m1, "name", true) + c0 := NewConnection(c0ID, ar, ebw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) + NewConnection(c1ID, br, eaw, m1, "name", CompressAlways) res := c0.ping() if (i < 8 || j < 8) && res { @@ -167,8 +167,8 @@ func TestVersionErr(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() - c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection) - NewConnection(c1ID, br, aw, m1, "name", true) + c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) + NewConnection(c1ID, br, aw, m1, "name", CompressAlways) w := xdr.NewWriter(c0.cw) w.WriteUint32(encodeHeader(header{ @@ -190,8 +190,8 @@ func TestTypeErr(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() - c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection) - NewConnection(c1ID, br, aw, m1, "name", true) + c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) + NewConnection(c1ID, br, aw, m1, "name", CompressAlways) w := xdr.NewWriter(c0.cw) w.WriteUint32(encodeHeader(header{ @@ -213,8 +213,8 @@ func TestClose(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() - c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection) - NewConnection(c1ID, br, aw, m1, "name", true) + c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) + NewConnection(c1ID, br, aw, m1, "name", CompressAlways) c0.close(nil)