From f376c79f7f2c56568fbcad5d6062f2b9547b5887 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Wed, 24 Jun 2015 12:39:46 +0100 Subject: [PATCH] Add initial code --- cmd/relaysrv/main.go | 88 ++++++ cmd/relaysrv/protocol/packets.go | 45 +++ cmd/relaysrv/protocol/packets_xdr.go | 415 +++++++++++++++++++++++++++ cmd/relaysrv/protocol_listener.go | 230 +++++++++++++++ cmd/relaysrv/session.go | 173 +++++++++++ cmd/relaysrv/session_listener.go | 59 ++++ cmd/relaysrv/utils.go | 53 ++++ 7 files changed, 1063 insertions(+) create mode 100644 cmd/relaysrv/main.go create mode 100644 cmd/relaysrv/protocol/packets.go create mode 100644 cmd/relaysrv/protocol/packets_xdr.go create mode 100644 cmd/relaysrv/protocol_listener.go create mode 100644 cmd/relaysrv/session.go create mode 100644 cmd/relaysrv/session_listener.go create mode 100644 cmd/relaysrv/utils.go diff --git a/cmd/relaysrv/main.go b/cmd/relaysrv/main.go new file mode 100644 index 000000000..3c4d533ed --- /dev/null +++ b/cmd/relaysrv/main.go @@ -0,0 +1,88 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "crypto/tls" + "flag" + "log" + "os" + "path/filepath" + "sync" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/relaysrv/protocol" +) + +var ( + listenProtocol string + listenSession string + debug bool + + sessionAddress []byte + sessionPort uint16 + + networkTimeout time.Duration + pingInterval time.Duration + messageTimeout time.Duration + + pingMessage message + + mut = sync.RWMutex{} + outbox = make(map[syncthingprotocol.DeviceID]chan message) +) + +func main() { + var dir, extAddress string + + pingPayload := protocol.Ping{}.MustMarshalXDR() + pingMessage = message{ + header: protocol.Header{ + Magic: protocol.Magic, + MessageType: protocol.MessageTypePing, + MessageLength: int32(len(pingPayload)), + }, + payload: pingPayload, + } + + flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address") + flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address") + flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection") + flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") + flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations") + flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") + flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") + + flag.BoolVar(&debug, "debug", false, "Enable debug output") + flag.Parse() + + certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalln("Failed to load X509 key pair:", err) + } + + tlsCfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{protocol.ProtocolName}, + ClientAuth: tls.RequestClientCert, + SessionTicketsDisabled: true, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, + } + + log.SetOutput(os.Stdout) + + go sessionListener(listenSession) + + protocolListener(listenProtocol, tlsCfg) +} diff --git a/cmd/relaysrv/protocol/packets.go b/cmd/relaysrv/protocol/packets.go new file mode 100644 index 000000000..4675d1cf4 --- /dev/null +++ b/cmd/relaysrv/protocol/packets.go @@ -0,0 +1,45 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +//go:generate -command genxdr go run ../../syncthing/Godeps/_workspace/src/github.com/calmh/xdr/cmd/genxdr/main.go +//go:generate genxdr -o packets_xdr.go packets.go + +package protocol + +import ( + "unsafe" +) + +const ( + Magic = 0x9E79BC40 + HeaderSize = unsafe.Sizeof(&Header{}) + ProtocolName = "bep-relay" +) + +const ( + MessageTypePing int32 = iota + MessageTypePong + MessageTypeJoinRequest + MessageTypeConnectRequest + MessageTypeSessionInvitation +) + +type Header struct { + Magic uint32 + MessageType int32 + MessageLength int32 +} + +type Ping struct{} +type Pong struct{} +type JoinRequest struct{} + +type ConnectRequest struct { + ID []byte // max:32 +} + +type SessionInvitation struct { + Key []byte // max:32 + Address []byte // max:32 + Port uint16 + ServerSocket bool +} diff --git a/cmd/relaysrv/protocol/packets_xdr.go b/cmd/relaysrv/protocol/packets_xdr.go new file mode 100644 index 000000000..ca547e007 --- /dev/null +++ b/cmd/relaysrv/protocol/packets_xdr.go @@ -0,0 +1,415 @@ +// ************************************************************ +// This file is automatically generated by genxdr. Do not edit. +// ************************************************************ + +package protocol + +import ( + "bytes" + "io" + + "github.com/calmh/xdr" +) + +/* + +Header Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Magic | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Message Type | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Message Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct Header { + unsigned int Magic; + int MessageType; + int MessageLength; +} + +*/ + +func (o Header) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o Header) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o Header) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Header) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o Header) EncodeXDRInto(xw *xdr.Writer) (int, error) { + xw.WriteUint32(o.Magic) + xw.WriteUint32(uint32(o.MessageType)) + xw.WriteUint32(uint32(o.MessageLength)) + return xw.Tot(), xw.Error() +} + +func (o *Header) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *Header) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *Header) DecodeXDRFrom(xr *xdr.Reader) error { + o.Magic = xr.ReadUint32() + o.MessageType = int32(xr.ReadUint32()) + o.MessageLength = int32(xr.ReadUint32()) + return xr.Error() +} + +/* + +Ping Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct Ping { +} + +*/ + +func (o Ping) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o Ping) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o Ping) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Ping) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o Ping) EncodeXDRInto(xw *xdr.Writer) (int, error) { + return xw.Tot(), xw.Error() +} + +func (o *Ping) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *Ping) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *Ping) DecodeXDRFrom(xr *xdr.Reader) error { + return xr.Error() +} + +/* + +Pong Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct Pong { +} + +*/ + +func (o Pong) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o Pong) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o Pong) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Pong) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o Pong) EncodeXDRInto(xw *xdr.Writer) (int, error) { + return xw.Tot(), xw.Error() +} + +func (o *Pong) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *Pong) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *Pong) DecodeXDRFrom(xr *xdr.Reader) error { + return xr.Error() +} + +/* + +JoinRequest Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct JoinRequest { +} + +*/ + +func (o JoinRequest) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o JoinRequest) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o JoinRequest) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o JoinRequest) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o JoinRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { + return xw.Tot(), xw.Error() +} + +func (o *JoinRequest) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *JoinRequest) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *JoinRequest) DecodeXDRFrom(xr *xdr.Reader) error { + return xr.Error() +} + +/* + +ConnectRequest Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of ID | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ ID (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct ConnectRequest { + opaque ID<32>; +} + +*/ + +func (o ConnectRequest) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o ConnectRequest) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o ConnectRequest) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o ConnectRequest) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o ConnectRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { + if l := len(o.ID); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("ID", l, 32) + } + xw.WriteBytes(o.ID) + return xw.Tot(), xw.Error() +} + +func (o *ConnectRequest) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *ConnectRequest) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *ConnectRequest) DecodeXDRFrom(xr *xdr.Reader) error { + o.ID = xr.ReadBytesMax(32) + return xr.Error() +} + +/* + +SessionInvitation Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of Key | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Key (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of Address | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Address (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| 0x0000 | Port | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Server Socket (V=0 or 1) |V| ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct SessionInvitation { + opaque Key<32>; + opaque Address<32>; + unsigned int Port; + bool ServerSocket; +} + +*/ + +func (o SessionInvitation) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o SessionInvitation) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o SessionInvitation) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o SessionInvitation) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o SessionInvitation) EncodeXDRInto(xw *xdr.Writer) (int, error) { + if l := len(o.Key); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32) + } + xw.WriteBytes(o.Key) + if l := len(o.Address); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("Address", l, 32) + } + xw.WriteBytes(o.Address) + xw.WriteUint16(o.Port) + xw.WriteBool(o.ServerSocket) + return xw.Tot(), xw.Error() +} + +func (o *SessionInvitation) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *SessionInvitation) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *SessionInvitation) DecodeXDRFrom(xr *xdr.Reader) error { + o.Key = xr.ReadBytesMax(32) + o.Address = xr.ReadBytesMax(32) + o.Port = xr.ReadUint16() + o.ServerSocket = xr.ReadBool() + return xr.Error() +} diff --git a/cmd/relaysrv/protocol_listener.go b/cmd/relaysrv/protocol_listener.go new file mode 100644 index 000000000..b6d89b226 --- /dev/null +++ b/cmd/relaysrv/protocol_listener.go @@ -0,0 +1,230 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "crypto/tls" + "io" + "log" + "net" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + + "github.com/syncthing/relaysrv/protocol" +) + +type message struct { + header protocol.Header + payload []byte +} + +func protocolListener(addr string, config *tls.Config) { + listener, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalln(err) + } + + for { + conn, err := listener.Accept() + if err != nil { + if debug { + log.Println(err) + } + continue + } + + if debug { + log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) + } + + go protocolConnectionHandler(conn, config) + } +} + +func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { + err := setTCPOptions(tcpConn) + if err != nil && debug { + log.Println("Failed to set TCP options on protocol connection", tcpConn.RemoteAddr(), err) + } + + conn := tls.Server(tcpConn, config) + err = conn.Handshake() + if err != nil { + log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err) + conn.Close() + return + } + + state := conn.ConnectionState() + if (!state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol != protocol.ProtocolName) && debug { + log.Println("Protocol negotiation error") + } + + certs := state.PeerCertificates + if len(certs) != 1 { + log.Println("Certificate list error") + conn.Close() + return + } + + deviceId := syncthingprotocol.NewDeviceID(certs[0].Raw) + + mut.RLock() + _, ok := outbox[deviceId] + mut.RUnlock() + if ok { + log.Println("Already have a peer with the same ID", deviceId, conn.RemoteAddr()) + conn.Close() + return + } + + errorChannel := make(chan error) + messageChannel := make(chan message) + outboxChannel := make(chan message) + + go readerLoop(conn, messageChannel, errorChannel) + + pingTicker := time.NewTicker(pingInterval) + timeoutTicker := time.NewTimer(messageTimeout * 2) + joined := false + + for { + select { + case msg := <-messageChannel: + switch msg.header.MessageType { + case protocol.MessageTypeJoinRequest: + mut.Lock() + outbox[deviceId] = outboxChannel + mut.Unlock() + joined = true + case protocol.MessageTypeConnectRequest: + // We will disconnect after this message, no matter what, + // because, we've either sent out an invitation, or we don't + // have the peer available. + var fmsg protocol.ConnectRequest + err := fmsg.UnmarshalXDR(msg.payload) + if err != nil { + log.Println(err) + conn.Close() + continue + } + + requestedPeer := syncthingprotocol.DeviceIDFromBytes(fmsg.ID) + mut.RLock() + peerOutbox, ok := outbox[requestedPeer] + mut.RUnlock() + if !ok { + if debug { + log.Println("Do not have", requestedPeer) + } + conn.Close() + continue + } + + ses := newSession() + + smsg, err := ses.GetServerInvitationMessage() + if err != nil { + log.Println("Error getting server invitation", requestedPeer) + conn.Close() + continue + } + cmsg, err := ses.GetClientInvitationMessage() + if err != nil { + log.Println("Error getting client invitation", requestedPeer) + conn.Close() + continue + } + + go ses.Serve() + + if err := sendMessage(cmsg, conn); err != nil { + log.Println("Failed to send invitation message", err) + } else { + peerOutbox <- smsg + if debug { + log.Println("Sent invitation from", deviceId, "to", requestedPeer) + } + } + conn.Close() + case protocol.MessageTypePong: + timeoutTicker.Reset(messageTimeout) + } + case err := <-errorChannel: + log.Println("Closing connection:", err) + return + case <-pingTicker.C: + if !joined { + log.Println(deviceId, "didn't join within", messageTimeout) + conn.Close() + continue + } + + if err := sendMessage(pingMessage, conn); err != nil { + log.Println(err) + conn.Close() + continue + } + case <-timeoutTicker.C: + // We should receive a error, which will cause us to quit the + // loop. + conn.Close() + case msg := <-outboxChannel: + if debug { + log.Println("Sending message to", deviceId, msg) + } + if err := sendMessage(msg, conn); err == nil { + log.Println(err) + conn.Close() + continue + } + } + } +} + +func readerLoop(conn *tls.Conn, messages chan<- message, errors chan<- error) { + header := make([]byte, protocol.HeaderSize) + data := make([]byte, 0, 0) + for { + _, err := io.ReadFull(conn, header) + if err != nil { + errors <- err + conn.Close() + return + } + + var hdr protocol.Header + err = hdr.UnmarshalXDR(header) + if err != nil { + conn.Close() + return + } + + if hdr.Magic != protocol.Magic { + conn.Close() + return + } + + if hdr.MessageLength > int32(cap(data)) { + data = make([]byte, 0, hdr.MessageLength) + } else { + data = data[:hdr.MessageLength] + } + + _, err = io.ReadFull(conn, data) + if err != nil { + errors <- err + conn.Close() + return + } + + msg := message{ + header: hdr, + payload: make([]byte, hdr.MessageLength), + } + copy(msg.payload, data[:hdr.MessageLength]) + + messages <- msg + } +} diff --git a/cmd/relaysrv/session.go b/cmd/relaysrv/session.go new file mode 100644 index 000000000..3466bd535 --- /dev/null +++ b/cmd/relaysrv/session.go @@ -0,0 +1,173 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "crypto/rand" + "net" + "sync" + "time" + + "github.com/syncthing/relaysrv/protocol" +) + +var ( + sessionmut = sync.Mutex{} + sessions = make(map[string]*session, 0) +) + +type session struct { + serverkey string + clientkey string + + mut sync.RWMutex + conns chan net.Conn +} + +func newSession() *session { + serverkey := make([]byte, 32) + _, err := rand.Read(serverkey) + if err != nil { + return nil + } + + clientkey := make([]byte, 32) + _, err = rand.Read(clientkey) + if err != nil { + return nil + } + + return &session{ + serverkey: string(serverkey), + clientkey: string(clientkey), + conns: make(chan net.Conn), + } +} + +func findSession(key string) *session { + sessionmut.Lock() + defer sessionmut.Unlock() + lob, ok := sessions[key] + if !ok { + return nil + + } + delete(sessions, key) + return lob +} + +func (l *session) AddConnection(conn net.Conn) { + select { + case l.conns <- conn: + default: + } +} + +func (l *session) Serve() { + + timedout := time.After(messageTimeout) + + sessionmut.Lock() + sessions[l.serverkey] = l + sessions[l.clientkey] = l + sessionmut.Unlock() + + conns := make([]net.Conn, 0, 2) + for { + select { + case conn := <-l.conns: + conns = append(conns, conn) + if len(conns) < 2 { + continue + } + + close(l.conns) + + wg := sync.WaitGroup{} + + wg.Add(2) + + go proxy(conns[0], conns[1], wg) + go proxy(conns[1], conns[0], wg) + + wg.Wait() + + break + case <-timedout: + sessionmut.Lock() + delete(sessions, l.serverkey) + delete(sessions, l.clientkey) + sessionmut.Unlock() + + for _, conn := range conns { + conn.Close() + } + + break + } + } +} + +func (l *session) GetClientInvitationMessage() (message, error) { + invitation := protocol.SessionInvitation{ + Key: []byte(l.clientkey), + Address: nil, + Port: 123, + ServerSocket: false, + } + data, err := invitation.MarshalXDR() + if err != nil { + return message{}, err + } + + return message{ + header: protocol.Header{ + Magic: protocol.Magic, + MessageType: protocol.MessageTypeSessionInvitation, + MessageLength: int32(len(data)), + }, + payload: data, + }, nil +} + +func (l *session) GetServerInvitationMessage() (message, error) { + invitation := protocol.SessionInvitation{ + Key: []byte(l.serverkey), + Address: nil, + Port: 123, + ServerSocket: true, + } + data, err := invitation.MarshalXDR() + if err != nil { + return message{}, err + } + + return message{ + header: protocol.Header{ + Magic: protocol.Magic, + MessageType: protocol.MessageTypeSessionInvitation, + MessageLength: int32(len(data)), + }, + payload: data, + }, nil +} + +func proxy(c1, c2 net.Conn, wg sync.WaitGroup) { + for { + buf := make([]byte, 1024) + c1.SetReadDeadline(time.Now().Add(networkTimeout)) + n, err := c1.Read(buf) + if err != nil { + break + } + + c2.SetWriteDeadline(time.Now().Add(networkTimeout)) + _, err = c2.Write(buf[:n]) + if err != nil { + break + } + } + c1.Close() + c2.Close() + wg.Done() +} diff --git a/cmd/relaysrv/session_listener.go b/cmd/relaysrv/session_listener.go new file mode 100644 index 000000000..b78c4f4b6 --- /dev/null +++ b/cmd/relaysrv/session_listener.go @@ -0,0 +1,59 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "io" + "log" + "net" + "time" +) + +func sessionListener(addr string) { + listener, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalln(err) + } + + for { + conn, err := listener.Accept() + if err != nil { + if debug { + log.Println(err) + } + continue + } + + if debug { + log.Println("Session listener accepted connection from", conn.RemoteAddr()) + } + + go sessionConnectionHandler(conn) + } +} + +func sessionConnectionHandler(conn net.Conn) { + conn.SetReadDeadline(time.Now().Add(messageTimeout)) + key := make([]byte, 32) + + _, err := io.ReadFull(conn, key) + if err != nil { + if debug { + log.Println("Failed to read key", err, conn.RemoteAddr()) + } + conn.Close() + return + } + + ses := findSession(string(key)) + if debug { + log.Println("Key", key, "by", conn.RemoteAddr(), "session", ses) + } + + if ses != nil { + ses.AddConnection(conn) + } else { + conn.Close() + return + } +} diff --git a/cmd/relaysrv/utils.go b/cmd/relaysrv/utils.go new file mode 100644 index 000000000..5388ba32e --- /dev/null +++ b/cmd/relaysrv/utils.go @@ -0,0 +1,53 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "errors" + "net" + "time" +) + +func setTCPOptions(conn net.Conn) error { + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return errors.New("Not a TCP connection") + } + if err := tcpConn.SetLinger(0); err != nil { + return err + } + if err := tcpConn.SetNoDelay(true); err != nil { + return err + } + if err := tcpConn.SetKeepAlivePeriod(60 * time.Second); err != nil { + return err + } + if err := tcpConn.SetKeepAlive(true); err != nil { + return err + } + return nil +} + +func sendMessage(msg message, conn net.Conn) error { + header, err := msg.header.MarshalXDR() + if err != nil { + return err + } + + err = conn.SetWriteDeadline(time.Now().Add(networkTimeout)) + if err != nil { + return err + } + + _, err = conn.Write(header) + if err != nil { + return err + } + + _, err = conn.Write(msg.payload) + if err != nil { + return err + } + + return nil +}