From f9bd59f031caa6f246a91c5a919e2e7313962f60 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 20 Jul 2015 11:38:00 +0200 Subject: [PATCH 1/2] Style and minor fixes, main package --- protocol_listener.go | 47 +++++++++++++++++++++++++++++++------------- session_listener.go | 29 +++++++++++++++++---------- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/protocol_listener.go b/protocol_listener.go index 8825af827..a7243ff69 100644 --- a/protocol_listener.go +++ b/protocol_listener.go @@ -27,7 +27,6 @@ func protocolListener(addr string, config *tls.Config) { for { conn, err := listener.Accept() - setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -35,6 +34,8 @@ func protocolListener(addr string, config *tls.Config) { continue } + setTCPOptions(conn) + if debug { log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) } @@ -74,16 +75,12 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { errors := make(chan error, 1) outbox := make(chan interface{}) - go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { - for { - msg, err := protocol.ReadMessage(conn) - if err != nil { - errors <- err - return - } - messages <- msg - } - }(conn, messages, errors) + // Read messages from the connection and send them on the messages + // channel. When there is an error, send it on the error channel and + // return. Applies also when the connection gets closed, so the pattern + // below is to close the connection on error, then wait for the error + // signal from messageReader to exit. + go messageReader(conn, messages, errors) pingTicker := time.NewTicker(pingInterval) timeoutTicker := time.NewTimer(networkTimeout) @@ -96,6 +93,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { if debug { log.Printf("Message %T from %s", message, id) } + switch msg := message.(type) { case protocol.JoinRelayRequest: outboxesMut.RLock() @@ -116,6 +114,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { joined = true protocol.WriteMessage(conn, protocol.ResponseSuccess) + case protocol.ConnectRequest: requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID) outboxesMut.RLock() @@ -151,7 +150,10 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { log.Println("Sent invitation from", id, "to", requestedPeer) } conn.Close() + case protocol.Pong: + // Nothing + default: if debug { log.Printf("Unknown message %s: %T", id, message) @@ -159,21 +161,25 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) conn.Close() } + case err := <-errors: if debug { log.Printf("Closing connection %s: %s", id, err) } - // Potentially closing a second time. close(outbox) + + // Potentially closing a second time. conn.Close() - // Only delete the outbox if the client join, as it migth be a - // lookup request coming from the same client. + + // Only delete the outbox if the client is joined, as it might be + // a lookup request coming from the same client. if joined { outboxesMut.Lock() delete(outboxes, id) outboxesMut.Unlock() } return + case <-pingTicker.C: if !joined { if debug { @@ -189,6 +195,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } conn.Close() } + case <-timeoutTicker.C: // We should receive a error from the reader loop, which will cause // us to quit this loop. @@ -196,6 +203,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { log.Printf("%s timed out", id) } conn.Close() + case msg := <-outbox: if debug { log.Printf("Sending message %T to %s", msg, id) @@ -209,3 +217,14 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } } } + +func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } +} diff --git a/session_listener.go b/session_listener.go index 6159ceef5..2f6bae9ab 100644 --- a/session_listener.go +++ b/session_listener.go @@ -18,7 +18,6 @@ func sessionListener(addr string) { for { conn, err := listener.Accept() - setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -26,6 +25,8 @@ func sessionListener(addr string) { continue } + setTCPOptions(conn) + if debug { log.Println("Session listener accepted connection from", conn.RemoteAddr()) } @@ -35,10 +36,17 @@ func sessionListener(addr string) { } func sessionConnectionHandler(conn net.Conn) { - conn.SetDeadline(time.Now().Add(messageTimeout)) + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + return + } + message, err := protocol.ReadMessage(conn) if err != nil { - conn.Close() return } @@ -51,7 +59,6 @@ func sessionConnectionHandler(conn net.Conn) { if ses == nil { protocol.WriteMessage(conn, protocol.ResponseNotFound) - conn.Close() return } @@ -60,24 +67,26 @@ func sessionConnectionHandler(conn net.Conn) { log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) } protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) - conn.Close() return } - err := protocol.WriteMessage(conn, protocol.ResponseSuccess) - if err != nil { + if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil { if debug { log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) } - conn.Close() return } - conn.SetDeadline(time.Time{}) + + if err := conn.SetDeadline(time.Time{}); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + return + } default: if debug { log.Println("Unexpected message from", conn.RemoteAddr(), message) } protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) - conn.Close() } } From 049d92b52581015a759593e64314fcbfa2462ad2 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 20 Jul 2015 11:56:10 +0200 Subject: [PATCH 2/2] Style and minor fixes, client package --- client/client.go | 282 ++++++++++++++++++++++++----------------------- 1 file changed, 143 insertions(+), 139 deletions(-) diff --git a/client/client.go b/client/client.go index d05944aca..7169e6a8b 100644 --- a/client/client.go +++ b/client/client.go @@ -14,27 +14,6 @@ import ( "github.com/syncthing/relaysrv/protocol" ) -func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) ProtocolClient { - closeInvitationsOnFinish := false - if invitations == nil { - closeInvitationsOnFinish = true - invitations = make(chan protocol.SessionInvitation) - } - return ProtocolClient{ - URI: uri, - Invitations: invitations, - - closeInvitationsOnFinish: closeInvitationsOnFinish, - - config: configForCerts(certs), - - timeout: time.Minute * 2, - - stop: make(chan struct{}), - stopped: make(chan struct{}), - } -} - type ProtocolClient struct { URI *url.URL Invitations chan protocol.SessionInvitation @@ -51,6 +30,129 @@ type ProtocolClient struct { conn *tls.Conn } +func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) *ProtocolClient { + closeInvitationsOnFinish := false + if invitations == nil { + closeInvitationsOnFinish = true + invitations = make(chan protocol.SessionInvitation) + } + + return &ProtocolClient{ + URI: uri, + Invitations: invitations, + + closeInvitationsOnFinish: closeInvitationsOnFinish, + + config: configForCerts(certs), + + timeout: time.Minute * 2, + + stop: make(chan struct{}), + stopped: make(chan struct{}), + } +} + +func (c *ProtocolClient) Serve() { + c.stop = make(chan struct{}) + c.stopped = make(chan struct{}) + defer close(c.stopped) + + if err := c.connect(); err != nil { + l.Infoln("Relay connect:", err) + return + } + + if debug { + l.Debugln(c, "connected", c.conn.RemoteAddr()) + } + + if err := c.join(); err != nil { + c.conn.Close() + l.Infoln("Relay join:", err) + return + } + + if err := c.conn.SetDeadline(time.Time{}); err != nil { + l.Infoln("Relay set deadline:", err) + return + } + + if debug { + l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr()) + } + + defer c.cleanup() + + messages := make(chan interface{}) + errors := make(chan error, 1) + + go messageReader(c.conn, messages, errors) + + timeout := time.NewTimer(c.timeout) + + for { + select { + case message := <-messages: + timeout.Reset(c.timeout) + if debug { + log.Printf("%s received message %T", c, message) + } + + switch msg := message.(type) { + case protocol.Ping: + if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil { + l.Infoln("Relay write:", err) + return + + } + if debug { + l.Debugln(c, "sent pong") + } + + case protocol.SessionInvitation: + ip := net.IP(msg.Address) + if len(ip) == 0 || ip.IsUnspecified() { + msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:] + } + c.Invitations <- msg + + default: + l.Infoln("Relay: protocol error: unexpected message %v", msg) + return + } + + case <-c.stop: + if debug { + l.Debugln(c, "stopping") + } + return + + case err := <-errors: + l.Infoln("Relay received:", err) + return + + case <-timeout.C: + if debug { + l.Debugln(c, "timed out") + } + return + } + } +} + +func (c *ProtocolClient) Stop() { + if c.stop == nil { + return + } + + close(c.stop) + <-c.stopped +} + +func (c *ProtocolClient) String() string { + return fmt.Sprintf("ProtocolClient@%p", c) +} + func (c *ProtocolClient) connect() error { if c.URI.Scheme != "relay" { return fmt.Errorf("Unsupported relay schema:", c.URI.Scheme) @@ -61,9 +163,13 @@ func (c *ProtocolClient) connect() error { return err } - conn.SetDeadline(time.Now().Add(10 * time.Second)) + if err := conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil { + conn.Close() + return err + } if err := performHandshakeAndValidation(conn, c.URI); err != nil { + conn.Close() return err } @@ -71,101 +177,6 @@ func (c *ProtocolClient) connect() error { return nil } -func (c *ProtocolClient) Serve() { - if err := c.connect(); err != nil { - panic(err) - } - - if debug { - l.Debugln(c, "connected", c.conn.RemoteAddr()) - } - - if err := c.join(); err != nil { - c.conn.Close() - panic(err) - } - - c.conn.SetDeadline(time.Time{}) - - if debug { - l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr()) - } - - c.stop = make(chan struct{}) - c.stopped = make(chan struct{}) - - defer c.cleanup() - - messages := make(chan interface{}) - errors := make(chan error, 1) - - go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { - for { - msg, err := protocol.ReadMessage(conn) - if err != nil { - errors <- err - return - } - messages <- msg - } - }(c.conn, messages, errors) - - timeout := time.NewTimer(c.timeout) - for { - select { - case message := <-messages: - timeout.Reset(c.timeout) - if debug { - log.Printf("%s received message %T", c, message) - } - switch msg := message.(type) { - case protocol.Ping: - if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil { - panic(err) - } - if debug { - l.Debugln(c, "sent pong") - } - case protocol.SessionInvitation: - ip := net.IP(msg.Address) - if len(ip) == 0 || ip.IsUnspecified() { - msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:] - } - c.Invitations <- msg - default: - panic(fmt.Errorf("protocol error: unexpected message %v", msg)) - } - case <-c.stop: - if debug { - l.Debugln(c, "stopping") - } - break - case err := <-errors: - panic(err) - case <-timeout.C: - if debug { - l.Debugln(c, "timed out") - } - return - } - } - - c.stopped <- struct{}{} -} - -func (c *ProtocolClient) Stop() { - if c.stop == nil { - return - } - - c.stop <- struct{}{} - <-c.stopped -} - -func (c *ProtocolClient) String() string { - return fmt.Sprintf("ProtocolClient@%p", c) -} - func (c *ProtocolClient) cleanup() { if c.closeInvitationsOnFinish { close(c.Invitations) @@ -176,24 +187,11 @@ func (c *ProtocolClient) cleanup() { l.Debugln(c, "cleaning up") } - if c.stop != nil { - close(c.stop) - c.stop = nil - } - - if c.stopped != nil { - close(c.stopped) - c.stopped = nil - } - - if c.conn != nil { - c.conn.Close() - } + c.conn.Close() } func (c *ProtocolClient) join() error { - err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}) - if err != nil { + if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}); err != nil { return err } @@ -207,6 +205,7 @@ func (c *ProtocolClient) join() error { if msg.Code != 0 { return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) } + default: return fmt.Errorf("protocol error: expecting response got %v", msg) } @@ -215,15 +214,12 @@ func (c *ProtocolClient) join() error { } func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error { - err := conn.Handshake() - if err != nil { - conn.Close() + if err := conn.Handshake(); err != nil { return err } cs := conn.ConnectionState() if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName { - conn.Close() return fmt.Errorf("protocol negotiation error") } @@ -232,22 +228,30 @@ func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error { if relayIDs != "" { relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs) if err != nil { - conn.Close() return fmt.Errorf("relay address contains invalid verification id: %s", err) } certs := cs.PeerCertificates if cl := len(certs); cl != 1 { - conn.Close() return fmt.Errorf("unexpected certificate count: %d", cl) } remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw) if remoteID != relayID { - conn.Close() return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID) } } return nil } + +func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } +}