Add more fine grained compression control

This commit is contained in:
Jakob Borg 2015-02-23 09:30:47 +01:00
parent cd0cce4195
commit 108b4e2e10
4 changed files with 135 additions and 27 deletions

53
compression.go Normal file
View File

@ -0,0 +1,53 @@
// Copyright (C) 2015 The Protocol Authors.
package protocol
import "fmt"
type Compression int
const (
CompressMetadata Compression = iota // zero value is the default, default should be "metadata"
CompressNever
CompressAlways
compressionThreshold = 128 // don't bother compressing messages smaller than this many bytes
)
var compressionMarshal = map[Compression]string{
CompressNever: "never",
CompressMetadata: "metadata",
CompressAlways: "always",
}
var compressionUnmarshal = map[string]Compression{
// Legacy
"false": CompressNever,
"true": CompressMetadata,
// Current
"never": CompressNever,
"metadata": CompressMetadata,
"always": CompressAlways,
}
func (c Compression) String() string {
s, ok := compressionMarshal[c]
if !ok {
return fmt.Sprintf("unknown:%d", c)
}
return s
}
func (c Compression) GoString() string {
return fmt.Sprintf("%q", c.String())
}
func (c Compression) MarshalText() ([]byte, error) {
return []byte(compressionMarshal[c]), nil
}
func (c *Compression) UnmarshalText(bs []byte) error {
*c = compressionUnmarshal[string(bs)]
return nil
}

51
compression_test.go Normal file
View File

@ -0,0 +1,51 @@
// Copyright (C) 2015 The Protocol Authors.
package protocol
import "testing"
func TestCompressionMarshal(t *testing.T) {
uTestcases := []struct {
s string
c Compression
}{
{"true", CompressMetadata},
{"false", CompressNever},
{"never", CompressNever},
{"metadata", CompressMetadata},
{"filedata", CompressFiledata},
{"always", CompressAlways},
{"whatever", CompressNever},
}
mTestcases := []struct {
s string
c Compression
}{
{"never", CompressNever},
{"metadata", CompressMetadata},
{"filedata", CompressFiledata},
{"always", CompressAlways},
}
var c Compression
for _, tc := range uTestcases {
err := c.UnmarshalText([]byte(tc.s))
if err != nil {
t.Error(err)
}
if c != tc.c {
t.Errorf("%s unmarshalled to %d, not %d", tc.s, c, tc.c)
}
}
for _, tc := range mTestcases {
bs, err := tc.c.MarshalText()
if err != nil {
t.Error(err)
}
if s := string(bs); s != tc.s {
t.Errorf("%d marshalled to %q, not %q", tc.c, s, tc.s)
}
}
}

View File

@ -106,7 +106,7 @@ type rawConnection struct {
closed chan struct{}
once sync.Once
compressionThreshold int // compress messages larger than this many bytes
compression Compression
rdbuf0 []byte // used & reused by readMessage
rdbuf1 []byte // used & reused by readMessage
@ -135,25 +135,21 @@ const (
pingIdleTime = 60 * time.Second
)
func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress bool) Connection {
func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection {
cr := &countingReader{Reader: reader}
cw := &countingWriter{Writer: writer}
compThres := 1<<31 - 1 // compression disabled
if compress {
compThres = 128 // compress messages that are 128 bytes long or larger
}
c := rawConnection{
id: deviceID,
name: name,
receiver: nativeModel{receiver},
state: stateInitial,
cr: cr,
cw: cw,
outbox: make(chan hdrMsg),
nextID: make(chan int),
closed: make(chan struct{}),
compressionThreshold: compThres,
id: deviceID,
name: name,
receiver: nativeModel{receiver},
state: stateInitial,
cr: cr,
cw: cw,
outbox: make(chan hdrMsg),
nextID: make(chan int),
closed: make(chan struct{}),
compression: compress,
}
go c.readerLoop()
@ -571,7 +567,15 @@ func (c *rawConnection) writerLoop() {
return
}
if len(uncBuf) >= c.compressionThreshold {
compress := false
switch c.compression {
case CompressAlways:
compress = true
case CompressMetadata:
compress = hm.hdr.msgType != messageTypeResponse
}
if compress && len(uncBuf) >= compressionThreshold {
// Use compression for large messages
hm.hdr.compression = true

View File

@ -67,8 +67,8 @@ func TestPing(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, nil, "name", true).(wireFormatConnection).next.(*rawConnection)
c1 := NewConnection(c1ID, br, aw, nil, "name", true).(wireFormatConnection).next.(*rawConnection)
c0 := NewConnection(c0ID, ar, bw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
c1 := NewConnection(c1ID, br, aw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
if ok := c0.ping(); !ok {
t.Error("c0 ping failed")
@ -91,8 +91,8 @@ func TestPingErr(t *testing.T) {
eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e}
ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
c0 := NewConnection(c0ID, ar, ebw, m0, "name", true).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, eaw, m1, "name", true)
c0 := NewConnection(c0ID, ar, ebw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, eaw, m1, "name", CompressAlways)
res := c0.ping()
if (i < 8 || j < 8) && res {
@ -167,8 +167,8 @@ func TestVersionErr(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", true)
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
w := xdr.NewWriter(c0.cw)
w.WriteUint32(encodeHeader(header{
@ -190,8 +190,8 @@ func TestTypeErr(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", true)
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
w := xdr.NewWriter(c0.cw)
w.WriteUint32(encodeHeader(header{
@ -213,8 +213,8 @@ func TestClose(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", true)
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
c0.close(nil)