Implement global and per session rate limiting

This commit is contained in:
Jakob Borg 2015-07-20 13:25:08 +02:00
parent c318fdc94b
commit 98a13204b2
3 changed files with 76 additions and 11 deletions

View File

@ -10,6 +10,7 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"github.com/juju/ratelimit"
"github.com/syncthing/relaysrv/protocol" "github.com/syncthing/relaysrv/protocol"
syncthingprotocol "github.com/syncthing/protocol" syncthingprotocol "github.com/syncthing/protocol"
@ -26,6 +27,11 @@ var (
networkTimeout time.Duration networkTimeout time.Duration
pingInterval time.Duration pingInterval time.Duration
messageTimeout time.Duration messageTimeout time.Duration
sessionLimitBps int
globalLimitBps int
sessionLimiter *ratelimit.Bucket
globalLimiter *ratelimit.Bucket
) )
func main() { func main() {
@ -38,6 +44,11 @@ func main() {
flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations")
flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent")
flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive")
flag.IntVar(&sessionLimitBps, "per-session-rate", sessionLimitBps, "Per session rate limit, in bytes/s")
flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s")
flag.BoolVar(&debug, "debug", false, "Enable debug output")
flag.Parse()
if extAddress == "" { if extAddress == "" {
extAddress = listenSession extAddress = listenSession
@ -51,10 +62,6 @@ func main() {
sessionAddress = addr.IP[:] sessionAddress = addr.IP[:]
sessionPort = uint16(addr.Port) sessionPort = uint16(addr.Port)
flag.BoolVar(&debug, "debug", false, "Enable debug output")
flag.Parse()
certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem")
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
@ -83,6 +90,13 @@ func main() {
log.Println("ID:", id) log.Println("ID:", id)
} }
if sessionLimitBps > 0 {
sessionLimiter = ratelimit.NewBucketWithRate(float64(sessionLimitBps), int64(2*sessionLimitBps))
}
if globalLimitBps > 0 {
globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps))
}
go sessionListener(listenSession) go sessionListener(listenSession)
protocolListener(listenProtocol, tlsCfg) protocolListener(listenProtocol, tlsCfg)

View File

@ -130,7 +130,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
continue continue
} }
ses := newSession() ses := newSession(sessionLimiter, globalLimiter)
go ses.Serve() go ses.Serve()

View File

@ -11,6 +11,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/juju/ratelimit"
"github.com/syncthing/relaysrv/protocol" "github.com/syncthing/relaysrv/protocol"
syncthingprotocol "github.com/syncthing/protocol" syncthingprotocol "github.com/syncthing/protocol"
@ -25,10 +26,12 @@ type session struct {
serverkey []byte serverkey []byte
clientkey []byte clientkey []byte
rateLimit func(bytes int64)
conns chan net.Conn conns chan net.Conn
} }
func newSession() *session { func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
serverkey := make([]byte, 32) serverkey := make([]byte, 32)
_, err := rand.Read(serverkey) _, err := rand.Read(serverkey)
if err != nil { if err != nil {
@ -44,6 +47,7 @@ func newSession() *session {
ses := &session{ ses := &session{
serverkey: serverkey, serverkey: serverkey,
clientkey: clientkey, clientkey: clientkey,
rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
conns: make(chan net.Conn), conns: make(chan net.Conn),
} }
@ -112,12 +116,12 @@ func (s *session) Serve() {
errors := make(chan error, 2) errors := make(chan error, 2)
go func() { go func() {
errors <- proxy(conns[0], conns[1]) errors <- s.proxy(conns[0], conns[1])
wg.Done() wg.Done()
}() }()
go func() { go func() {
errors <- proxy(conns[1], conns[0]) errors <- s.proxy(conns[1], conns[0])
wg.Done() wg.Done()
}() }()
@ -169,14 +173,15 @@ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) pr
} }
} }
func proxy(c1, c2 net.Conn) error { func (s *session) proxy(c1, c2 net.Conn) error {
if debug { if debug {
log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
} }
buf := make([]byte, 1024)
buf := make([]byte, 65536)
for { for {
c1.SetReadDeadline(time.Now().Add(networkTimeout)) c1.SetReadDeadline(time.Now().Add(networkTimeout))
n, err := c1.Read(buf[0:]) n, err := c1.Read(buf)
if err != nil { if err != nil {
return err return err
} }
@ -185,6 +190,10 @@ func proxy(c1, c2 net.Conn) error {
log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr()) log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
} }
if s.rateLimit != nil {
s.rateLimit(int64(n))
}
c2.SetWriteDeadline(time.Now().Add(networkTimeout)) c2.SetWriteDeadline(time.Now().Add(networkTimeout))
_, err = c2.Write(buf[:n]) _, err = c2.Write(buf[:n])
if err != nil { if err != nil {
@ -196,3 +205,45 @@ func proxy(c1, c2 net.Conn) error {
func (s *session) String() string { func (s *session) String() string {
return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5]) return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
} }
func makeRateLimitFunc(sessionRateLimit, globalRateLimit *ratelimit.Bucket) func(int64) {
// This may be a case of super duper premature optimization... We build an
// optimized function to do the rate limiting here based on what we need
// to do and then use it in the loop.
if sessionRateLimit == nil && globalRateLimit == nil {
// No limiting needed. We could equally well return a func(int64){} and
// not do a nil check were we use it, but I think the nil check there
// makes it clear that there will be no limiting if none is
// configured...
return nil
}
if sessionRateLimit == nil {
// We only have a global limiter
return func(bytes int64) {
globalRateLimit.Wait(bytes)
}
}
if globalRateLimit == nil {
// We only have a session limiter
return func(bytes int64) {
sessionRateLimit.Wait(bytes)
}
}
// We have both. Queue the bytes on both the global and session specific
// rate limiters. Wait for both in parallell, so that the actual send
// happens when both conditions are satisfied. In practice this just means
// wait the longer of the two times.
return func(bytes int64) {
t0 := sessionRateLimit.Take(bytes)
t1 := globalRateLimit.Take(bytes)
if t0 > t1 {
time.Sleep(t0)
} else {
time.Sleep(t1)
}
}
}