diff --git a/cmd/discosrv/main.go b/cmd/discosrv/main.go index cc8660fdf..817c716f6 100644 --- a/cmd/discosrv/main.go +++ b/cmd/discosrv/main.go @@ -25,6 +25,7 @@ var ( certFile = "cert.pem" keyFile = "key.pem" debug = false + useHttp = false ) func main() { @@ -48,15 +49,20 @@ func main() { flag.StringVar(&certFile, "cert", certFile, "Certificate file") flag.StringVar(&keyFile, "key", keyFile, "Key file") flag.BoolVar(&debug, "debug", debug, "Debug") + flag.BoolVar(&useHttp, "http", useHttp, "Listen on HTTP (behind an HTTPS proxy)") flag.Parse() - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - log.Fatalln("Failed to load X509 key pair:", err) - } + var cert tls.Certificate + var err error + if !useHttp { + cert, err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalln("Failed to load X509 key pair:", err) + } - devID := protocol.NewDeviceID(cert.Certificate[0]) - log.Println("Server device ID is", devID) + devID := protocol.NewDeviceID(cert.Certificate[0]) + log.Println("Server device ID is", devID) + } db, err := sql.Open(backend, dsn) if err != nil { diff --git a/cmd/discosrv/querysrv.go b/cmd/discosrv/querysrv.go index e06d7d557..6b30d777e 100644 --- a/cmd/discosrv/querysrv.go +++ b/cmd/discosrv/querysrv.go @@ -3,9 +3,11 @@ package main import ( + "bytes" "crypto/tls" "database/sql" "encoding/json" + "encoding/pem" "log" "net" "net/http" @@ -39,39 +41,47 @@ type annRelay struct { func (s *querysrv) Serve() { s.limiter = lru.New(lruSize) - tlsCfg := &tls.Config{ - Certificates: []tls.Certificate{s.cert}, - ClientAuth: tls.RequestClientCert, - SessionTicketsDisabled: true, - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - }, + if useHttp { + listener, err := net.Listen("tcp", s.addr) + if err != nil { + log.Println("Listen:", err) + return + } + s.listener = listener + } else { + tlsCfg := &tls.Config{ + Certificates: []tls.Certificate{s.cert}, + ClientAuth: tls.RequestClientCert, + SessionTicketsDisabled: true, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, + } + + tlsListener, err := tls.Listen("tcp", s.addr, tlsCfg) + if err != nil { + log.Println("Listen:", err) + return + } + s.listener = tlsListener } http.HandleFunc("/", s.handler) http.HandleFunc("/ping", handlePing) - tlsListener, err := tls.Listen("tcp", s.addr, tlsCfg) - if err != nil { - log.Println("Listen:", err) - return - } - - s.listener = tlsListener - srv := &http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, MaxHeaderBytes: 1 << 10, } - if err := srv.Serve(tlsListener); err != nil { + if err := srv.Serve(s.listener); err != nil { log.Println("Serve:", err) } } @@ -81,16 +91,22 @@ func (s *querysrv) handler(w http.ResponseWriter, req *http.Request) { log.Println(req.Method, req.URL) } - remoteAddr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr) - if err != nil { - log.Println("remoteAddr:", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return + var remoteIP net.IP + if useHttp { + remoteIP = net.ParseIP(req.Header.Get("X-Forwarded-For")) + } else { + addr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr) + if err != nil { + log.Println("remoteAddr:", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + remoteIP = addr.IP } - if s.limit(remoteAddr.IP) { + if s.limit(remoteIP) { if debug { - log.Println(remoteAddr.IP, "is limited") + log.Println(remoteIP, "is limited") } w.Header().Set("Retry-After", "60") http.Error(w, "Too Many Requests", 429) @@ -101,7 +117,7 @@ func (s *querysrv) handler(w http.ResponseWriter, req *http.Request) { case "GET": s.handleGET(w, req) case "POST": - s.handlePOST(w, req) + s.handlePOST(remoteIP, w, req) default: globalStats.Error() http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) @@ -109,15 +125,6 @@ func (s *querysrv) handler(w http.ResponseWriter, req *http.Request) { } func (s *querysrv) handleGET(w http.ResponseWriter, req *http.Request) { - if req.TLS == nil { - if debug { - log.Println(req.Method, req.URL, "not TLS") - } - globalStats.Error() - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - deviceID, err := protocol.DeviceIDFromString(req.URL.Query().Get("device")) if err != nil { if debug { @@ -159,17 +166,9 @@ func (s *querysrv) handleGET(w http.ResponseWriter, req *http.Request) { json.NewEncoder(w).Encode(ann) } -func (s *querysrv) handlePOST(w http.ResponseWriter, req *http.Request) { - if req.TLS == nil { - if debug { - log.Println(req.Method, req.URL, "not TLS") - } - globalStats.Error() - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - - if len(req.TLS.PeerCertificates) == 0 { +func (s *querysrv) handlePOST(remoteIP net.IP, w http.ResponseWriter, req *http.Request) { + rawCert := certificateBytes(req) + if rawCert == nil { if debug { log.Println(req.Method, req.URL, "no certificates") } @@ -188,22 +187,14 @@ func (s *querysrv) handlePOST(w http.ResponseWriter, req *http.Request) { return } - remoteAddr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr) - if err != nil { - log.Println("remoteAddr:", err) - globalStats.Error() - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - deviceID := protocol.NewDeviceID(req.TLS.PeerCertificates[0].Raw) + deviceID := protocol.NewDeviceID(rawCert) // handleAnnounce returns *two* errors. The first indicates a problem with // something the client posted to us. We should return a 400 Bad Request // and not worry about it. The second indicates that the request was fine, // but something internal fucked up. We should log it and respond with a // more apologetic 500 Internal Server Error. - userErr, internalErr := s.handleAnnounce(remoteAddr.IP, deviceID, ann.Direct, ann.Relays) + userErr, internalErr := s.handleAnnounce(remoteIP, deviceID, ann.Direct, ann.Relays) if userErr != nil { if debug { log.Println(req.Method, req.URL, userErr) @@ -396,3 +387,32 @@ func (s *querysrv) getRelays(device protocol.DeviceID) ([]annRelay, error) { func handlePing(w http.ResponseWriter, r *http.Request) { w.WriteHeader(204) } + +func certificateBytes(req *http.Request) []byte { + if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 { + return req.TLS.PeerCertificates[0].Raw + } + + if hdr := req.Header.Get("X-SSL-Cert"); hdr != "" { + bs := []byte(hdr) + // The certificate is in PEM format but with spaces for newlines. We + // need to reinstate the newlines for the PEM decoder. But we need to + // leave the spaces in the BEGIN and END lines - the first and last + // space - alone. + firstSpace := bytes.Index(bs, []byte(" ")) + lastSpace := bytes.LastIndex(bs, []byte(" ")) + for i := firstSpace + 1; i < lastSpace; i++ { + if bs[i] == ' ' { + bs[i] = '\n' + } + } + block, _ := pem.Decode(bs) + if block == nil { + // Decoding failed + return nil + } + return block.Bytes + } + + return nil +}