diff --git a/cmd/relaysrv/main.go b/cmd/relaysrv/main.go index 5ca060689..b429d94a3 100644 --- a/cmd/relaysrv/main.go +++ b/cmd/relaysrv/main.go @@ -10,6 +10,7 @@ import ( "path/filepath" "time" + "github.com/juju/ratelimit" "github.com/syncthing/relaysrv/protocol" syncthingprotocol "github.com/syncthing/protocol" @@ -26,6 +27,11 @@ var ( networkTimeout time.Duration pingInterval time.Duration messageTimeout time.Duration + + sessionLimitBps int + globalLimitBps int + sessionLimiter *ratelimit.Bucket + globalLimiter *ratelimit.Bucket ) func main() { @@ -38,6 +44,11 @@ func main() { flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") + flag.IntVar(&sessionLimitBps, "per-session-rate", sessionLimitBps, "Per session rate limit, in bytes/s") + flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s") + flag.BoolVar(&debug, "debug", false, "Enable debug output") + + flag.Parse() if extAddress == "" { extAddress = listenSession @@ -51,10 +62,6 @@ func main() { sessionAddress = addr.IP[:] sessionPort = uint16(addr.Port) - flag.BoolVar(&debug, "debug", false, "Enable debug output") - - flag.Parse() - certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { @@ -83,6 +90,13 @@ func main() { log.Println("ID:", id) } + if sessionLimitBps > 0 { + sessionLimiter = ratelimit.NewBucketWithRate(float64(sessionLimitBps), int64(2*sessionLimitBps)) + } + if globalLimitBps > 0 { + globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps)) + } + go sessionListener(listenSession) protocolListener(listenProtocol, tlsCfg) diff --git a/cmd/relaysrv/protocol_listener.go b/cmd/relaysrv/protocol_listener.go index c3321aa50..8825af827 100644 --- a/cmd/relaysrv/protocol_listener.go +++ b/cmd/relaysrv/protocol_listener.go @@ -130,7 +130,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { continue } - ses := newSession() + ses := newSession(sessionLimiter, globalLimiter) go ses.Serve() diff --git a/cmd/relaysrv/session.go b/cmd/relaysrv/session.go index c5a091952..c526ed5d3 100644 --- a/cmd/relaysrv/session.go +++ b/cmd/relaysrv/session.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/juju/ratelimit" "github.com/syncthing/relaysrv/protocol" syncthingprotocol "github.com/syncthing/protocol" @@ -25,10 +26,12 @@ type session struct { serverkey []byte clientkey []byte + rateLimit func(bytes int64) + conns chan net.Conn } -func newSession() *session { +func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { serverkey := make([]byte, 32) _, err := rand.Read(serverkey) if err != nil { @@ -44,6 +47,7 @@ func newSession() *session { ses := &session{ serverkey: serverkey, clientkey: clientkey, + rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit), conns: make(chan net.Conn), } @@ -112,12 +116,12 @@ func (s *session) Serve() { errors := make(chan error, 2) go func() { - errors <- proxy(conns[0], conns[1]) + errors <- s.proxy(conns[0], conns[1]) wg.Done() }() go func() { - errors <- proxy(conns[1], conns[0]) + errors <- s.proxy(conns[1], conns[0]) wg.Done() }() @@ -169,14 +173,15 @@ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) pr } } -func proxy(c1, c2 net.Conn) error { +func (s *session) proxy(c1, c2 net.Conn) error { if debug { log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) } - buf := make([]byte, 1024) + + buf := make([]byte, 65536) for { c1.SetReadDeadline(time.Now().Add(networkTimeout)) - n, err := c1.Read(buf[0:]) + n, err := c1.Read(buf) if err != nil { return err } @@ -185,6 +190,10 @@ func proxy(c1, c2 net.Conn) error { log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr()) } + if s.rateLimit != nil { + s.rateLimit(int64(n)) + } + c2.SetWriteDeadline(time.Now().Add(networkTimeout)) _, err = c2.Write(buf[:n]) if err != nil { @@ -196,3 +205,45 @@ func proxy(c1, c2 net.Conn) error { func (s *session) String() string { return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5]) } + +func makeRateLimitFunc(sessionRateLimit, globalRateLimit *ratelimit.Bucket) func(int64) { + // This may be a case of super duper premature optimization... We build an + // optimized function to do the rate limiting here based on what we need + // to do and then use it in the loop. + + if sessionRateLimit == nil && globalRateLimit == nil { + // No limiting needed. We could equally well return a func(int64){} and + // not do a nil check were we use it, but I think the nil check there + // makes it clear that there will be no limiting if none is + // configured... + return nil + } + + if sessionRateLimit == nil { + // We only have a global limiter + return func(bytes int64) { + globalRateLimit.Wait(bytes) + } + } + + if globalRateLimit == nil { + // We only have a session limiter + return func(bytes int64) { + sessionRateLimit.Wait(bytes) + } + } + + // We have both. Queue the bytes on both the global and session specific + // rate limiters. Wait for both in parallell, so that the actual send + // happens when both conditions are satisfied. In practice this just means + // wait the longer of the two times. + return func(bytes int64) { + t0 := sessionRateLimit.Take(bytes) + t1 := globalRateLimit.Take(bytes) + if t0 > t1 { + time.Sleep(t0) + } else { + time.Sleep(t1) + } + } +}