syncthing/cmd/relaysrv/session.go

199 lines
3.6 KiB
Go
Raw Normal View History

2015-06-24 13:39:46 +02:00
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package main
import (
"crypto/rand"
2015-06-28 02:52:01 +02:00
"encoding/hex"
"fmt"
"log"
2015-06-24 13:39:46 +02:00
"net"
"sync"
"time"
"github.com/syncthing/relaysrv/protocol"
2015-06-28 02:52:01 +02:00
syncthingprotocol "github.com/syncthing/protocol"
2015-06-24 13:39:46 +02:00
)
var (
2015-06-28 02:52:01 +02:00
sessionMut = sync.Mutex{}
2015-06-24 13:39:46 +02:00
sessions = make(map[string]*session, 0)
)
// session pairs two connections that join the relay independently. Each
// side is given one of the two keys through an invitation message (the
// server key via GetServerInvitationMessage, the client key via
// GetClientInvitationMessage).
type session struct {
	// Random keys (32 bytes each when created by newSession).
	serverkey, clientkey []byte

	// conns hands joining connections from AddConnection to Serve.
	conns chan net.Conn
}
func newSession() *session {
serverkey := make([]byte, 32)
_, err := rand.Read(serverkey)
if err != nil {
return nil
}
clientkey := make([]byte, 32)
_, err = rand.Read(clientkey)
if err != nil {
return nil
}
2015-06-28 02:52:01 +02:00
ses := &session{
serverkey: serverkey,
clientkey: clientkey,
2015-06-24 13:39:46 +02:00
conns: make(chan net.Conn),
}
2015-06-28 02:52:01 +02:00
if debug {
log.Println("New session", ses)
}
sessionMut.Lock()
sessions[string(ses.serverkey)] = ses
sessions[string(ses.clientkey)] = ses
sessionMut.Unlock()
return ses
2015-06-24 13:39:46 +02:00
}
func findSession(key string) *session {
2015-06-28 02:52:01 +02:00
sessionMut.Lock()
defer sessionMut.Unlock()
2015-06-24 13:39:46 +02:00
lob, ok := sessions[key]
if !ok {
return nil
}
delete(sessions, key)
return lob
}
2015-06-28 02:52:01 +02:00
func (s *session) AddConnection(conn net.Conn) bool {
if debug {
log.Println("New connection for", s, "from", conn.RemoteAddr())
}
2015-06-24 13:39:46 +02:00
select {
2015-06-28 02:52:01 +02:00
case s.conns <- conn:
return true
2015-06-24 13:39:46 +02:00
default:
}
2015-06-28 02:52:01 +02:00
return false
2015-06-24 13:39:46 +02:00
}
2015-06-28 02:52:01 +02:00
func (s *session) Serve() {
2015-06-24 13:39:46 +02:00
timedout := time.After(messageTimeout)
2015-06-28 02:52:01 +02:00
if debug {
log.Println("Session", s, "serving")
}
2015-06-24 13:39:46 +02:00
conns := make([]net.Conn, 0, 2)
for {
select {
2015-06-28 02:52:01 +02:00
case conn := <-s.conns:
2015-06-24 13:39:46 +02:00
conns = append(conns, conn)
if len(conns) < 2 {
continue
}
2015-06-28 02:52:01 +02:00
close(s.conns)
2015-06-24 13:39:46 +02:00
2015-06-28 02:52:01 +02:00
if debug {
log.Println("Session", s, "starting between", conns[0].RemoteAddr(), conns[1].RemoteAddr())
}
2015-06-24 13:39:46 +02:00
2015-06-28 02:52:01 +02:00
wg := sync.WaitGroup{}
2015-06-24 13:39:46 +02:00
wg.Add(2)
2015-06-28 02:52:01 +02:00
errors := make(chan error, 2)
go func() {
errors <- proxy(conns[0], conns[1])
wg.Done()
}()
go func() {
errors <- proxy(conns[1], conns[0])
wg.Done()
}()
2015-06-24 13:39:46 +02:00
wg.Wait()
2015-06-28 02:52:01 +02:00
if debug {
log.Println("Session", s, "ended, outcomes:", <-errors, <-errors)
}
goto done
2015-06-24 13:39:46 +02:00
case <-timedout:
2015-06-28 02:52:01 +02:00
if debug {
log.Println("Session", s, "timed out")
2015-06-24 13:39:46 +02:00
}
2015-06-28 02:52:01 +02:00
goto done
2015-06-24 13:39:46 +02:00
}
}
2015-06-28 02:52:01 +02:00
done:
sessionMut.Lock()
delete(sessions, string(s.serverkey))
delete(sessions, string(s.clientkey))
sessionMut.Unlock()
for _, conn := range conns {
conn.Close()
}
if debug {
log.Println("Session", s, "stopping")
}
2015-06-24 13:39:46 +02:00
}
2015-06-28 02:52:01 +02:00
func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation {
return protocol.SessionInvitation{
From: from[:],
Key: []byte(s.clientkey),
Address: sessionAddress,
Port: sessionPort,
2015-06-24 13:39:46 +02:00
ServerSocket: false,
}
}
2015-06-28 02:52:01 +02:00
func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation {
return protocol.SessionInvitation{
From: from[:],
Key: []byte(s.serverkey),
Address: sessionAddress,
Port: sessionPort,
2015-06-24 13:39:46 +02:00
ServerSocket: true,
}
}
2015-06-28 02:52:01 +02:00
func proxy(c1, c2 net.Conn) error {
if debug {
log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
}
buf := make([]byte, 1024)
2015-06-24 13:39:46 +02:00
for {
c1.SetReadDeadline(time.Now().Add(networkTimeout))
2015-06-28 02:52:01 +02:00
n, err := c1.Read(buf[0:])
2015-06-24 13:39:46 +02:00
if err != nil {
2015-06-28 02:52:01 +02:00
return err
}
if debug {
log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
2015-06-24 13:39:46 +02:00
}
c2.SetWriteDeadline(time.Now().Add(networkTimeout))
_, err = c2.Write(buf[:n])
if err != nil {
2015-06-28 02:52:01 +02:00
return err
2015-06-24 13:39:46 +02:00
}
}
2015-06-28 02:52:01 +02:00
}
func (s *session) String() string {
return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
2015-06-24 13:39:46 +02:00
}