lib/syncthing: Clean up / refactor LoadOrGenerateCertificate() utility function. (#8025)

LoadOrGenerateCertificate() takes two file path arguments, but then
uses the locations package to determine the actual path.  Fix that
with a minimally invasive change, by using the arguments instead.
Factor out GenerateCertificate().

The only caller of this function is cmd/syncthing, which passes the
same values, so this is technically a no-op.

* lib/tlsutil: Make storing generated certificate optional.  Avoid
  temporary cert and key files in tests, keep cert in memory.
This commit is contained in:
André Colomb 2021-11-07 23:59:48 +01:00 committed by GitHub
parent db15e52743
commit ec8a748514
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 50 additions and 78 deletions

View File

@ -49,16 +49,13 @@ import (
"github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/svcutil" "github.com/syncthing/syncthing/lib/svcutil"
"github.com/syncthing/syncthing/lib/syncthing" "github.com/syncthing/syncthing/lib/syncthing"
"github.com/syncthing/syncthing/lib/tlsutil"
"github.com/syncthing/syncthing/lib/upgrade" "github.com/syncthing/syncthing/lib/upgrade"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
const ( const (
tlsDefaultCommonName = "syncthing" sigTerm = syscall.Signal(15)
deviceCertLifetimeDays = 20 * 365
sigTerm = syscall.Signal(15)
) )
const ( const (
@ -442,7 +439,7 @@ func generate(generateDir string, noDefaultFolder bool) error {
if err == nil { if err == nil {
l.Warnln("Key exists; will not overwrite.") l.Warnln("Key exists; will not overwrite.")
} else { } else {
cert, err = tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, deviceCertLifetimeDays) cert, err = syncthing.GenerateCertificate(certFile, keyFile)
if err != nil { if err != nil {
return errors.Wrap(err, "create certificate") return errors.Wrap(err, "create certificate")
} }

View File

@ -1209,15 +1209,9 @@ func TestPrefixMatch(t *testing.T) {
} }
func TestShouldRegenerateCertificate(t *testing.T) { func TestShouldRegenerateCertificate(t *testing.T) {
dir, err := ioutil.TempDir("", "syncthing-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
// Self signed certificates expiring in less than a month are errored so we // Self signed certificates expiring in less than a month are errored so we
// can regenerate in time. // can regenerate in time.
crt, err := tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 29) crt, err := tlsutil.NewCertificateInMemory("foo.example.com", 29)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1226,7 +1220,7 @@ func TestShouldRegenerateCertificate(t *testing.T) {
} }
// Certificates with at least 31 days of life left are fine. // Certificates with at least 31 days of life left are fine.
crt, err = tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 31) crt, err = tlsutil.NewCertificateInMemory("foo.example.com", 31)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1236,7 +1230,7 @@ func TestShouldRegenerateCertificate(t *testing.T) {
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" {
// Certificates with too long an expiry time are not allowed on macOS // Certificates with too long an expiry time are not allowed on macOS
crt, err = tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 1000) crt, err = tlsutil.NewCertificateInMemory("foo.example.com", 1000)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -11,11 +11,9 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"net/url" "net/url"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -470,21 +468,9 @@ func withConnectionPair(b *testing.B, connUri string, h func(client, server inte
} }
func mustGetCert(b *testing.B) tls.Certificate { func mustGetCert(b *testing.B) tls.Certificate {
f1, err := ioutil.TempFile("", "") cert, err := tlsutil.NewCertificateInMemory("bench", 10)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
f1.Close()
f2, err := ioutil.TempFile("", "")
if err != nil {
b.Fatal(err)
}
f2.Close()
cert, err := tlsutil.NewCertificate(f1.Name(), f2.Name(), "bench", 10)
if err != nil {
b.Fatal(err)
}
_ = os.Remove(f1.Name())
_ = os.Remove(f2.Name())
return cert return cert
} }

View File

@ -107,13 +107,8 @@ func TestGlobalOverHTTP(t *testing.T) {
} }
func TestGlobalOverHTTPS(t *testing.T) { func TestGlobalOverHTTPS(t *testing.T) {
dir, err := ioutil.TempDir("", "syncthing")
if err != nil {
t.Fatal(err)
}
// Generate a server certificate. // Generate a server certificate.
cert, err := tlsutil.NewCertificate(dir+"/cert.pem", dir+"/key.pem", "syncthing", 30) cert, err := tlsutil.NewCertificateInMemory("syncthing", 30)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -172,13 +167,8 @@ func TestGlobalOverHTTPS(t *testing.T) {
} }
func TestGlobalAnnounce(t *testing.T) { func TestGlobalAnnounce(t *testing.T) {
dir, err := ioutil.TempDir("", "syncthing")
if err != nil {
t.Fatal(err)
}
// Generate a server certificate. // Generate a server certificate.
cert, err := tlsutil.NewCertificate(dir+"/cert.pem", dir+"/key.pem", "syncthing", 30) cert, err := tlsutil.NewCertificateInMemory("syncthing", 30)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -9,7 +9,6 @@ package syncthing
import ( import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath"
"testing" "testing"
"time" "time"
@ -57,13 +56,7 @@ func TestShortIDCheck(t *testing.T) {
} }
func TestStartupFail(t *testing.T) { func TestStartupFail(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "syncthing-TestStartupFail-") cert, err := tlsutil.NewCertificateInMemory("syncthing", 365)
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
cert, err := tlsutil.NewCertificate(filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key"), "syncthing", 365)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -25,22 +25,18 @@ import (
) )
func LoadOrGenerateCertificate(certFile, keyFile string) (tls.Certificate, error) { func LoadOrGenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair( cert, err := tls.LoadX509KeyPair(certFile, keyFile)
locations.Get(locations.CertFile),
locations.Get(locations.KeyFile),
)
if err != nil { if err != nil {
l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName) return GenerateCertificate(certFile, keyFile)
return tlsutil.NewCertificate(
locations.Get(locations.CertFile),
locations.Get(locations.KeyFile),
tlsDefaultCommonName,
deviceCertLifetimeDays,
)
} }
return cert, nil return cert, nil
} }
func GenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
return tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, deviceCertLifetimeDays)
}
func DefaultConfig(path string, myID protocol.DeviceID, evLogger events.Logger, noDefaultFolder bool) (config.Wrapper, error) { func DefaultConfig(path string, myID protocol.DeviceID, evLogger events.Logger, noDefaultFolder bool) (config.Wrapper, error) {
newCfg, err := config.NewWithFreePorts(myID) newCfg, err := config.NewWithFreePorts(myID)
if err != nil { if err != nil {

View File

@ -86,11 +86,11 @@ func SecureDefaultWithTLS12() *tls.Config {
} }
} }
// NewCertificate generates and returns a new TLS certificate. // generateCertificate generates a PEM formatted key pair and self-signed certificate in memory.
func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls.Certificate, error) { func generateCertificate(commonName string, lifetimeDays int) (*pem.Block, *pem.Block, error) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil { if err != nil {
return tls.Certificate{}, errors.Wrap(err, "generate key") return nil, nil, errors.Wrap(err, "generate key")
} }
notBefore := time.Now().Truncate(24 * time.Hour) notBefore := time.Now().Truncate(24 * time.Hour)
@ -117,19 +117,33 @@ func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
if err != nil { if err != nil {
return tls.Certificate{}, errors.Wrap(err, "create cert") return nil, nil, errors.Wrap(err, "create cert")
}
certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}
keyBlock, err := pemBlockForKey(priv)
if err != nil {
return nil, nil, errors.Wrap(err, "save key")
}
return certBlock, keyBlock, nil
}
// NewCertificate generates and returns a new TLS certificate, saved to the given PEM files.
func NewCertificate(certFile, keyFile string, commonName string, lifetimeDays int) (tls.Certificate, error) {
certBlock, keyBlock, err := generateCertificate(commonName, lifetimeDays)
if err != nil {
return tls.Certificate{}, err
} }
certOut, err := os.Create(certFile) certOut, err := os.Create(certFile)
if err != nil { if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save cert") return tls.Certificate{}, errors.Wrap(err, "save cert")
} }
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) if err = pem.Encode(certOut, certBlock); err != nil {
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save cert") return tls.Certificate{}, errors.Wrap(err, "save cert")
} }
err = certOut.Close() if err = certOut.Close(); err != nil {
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save cert") return tls.Certificate{}, errors.Wrap(err, "save cert")
} }
@ -137,22 +151,24 @@ func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls
if err != nil { if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key") return tls.Certificate{}, errors.Wrap(err, "save key")
} }
if err = pem.Encode(keyOut, keyBlock); err != nil {
block, err := pemBlockForKey(priv) return tls.Certificate{}, errors.Wrap(err, "save key")
if err != nil { }
if err = keyOut.Close(); err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key") return tls.Certificate{}, errors.Wrap(err, "save key")
} }
err = pem.Encode(keyOut, block) return tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
}
// NewCertificateInMemory generates and returns a new TLS certificate, kept only in memory.
func NewCertificateInMemory(commonName string, lifetimeDays int) (tls.Certificate, error) {
certBlock, keyBlock, err := generateCertificate(commonName, lifetimeDays)
if err != nil { if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key") return tls.Certificate{}, err
}
err = keyOut.Close()
if err != nil {
return tls.Certificate{}, errors.Wrap(err, "save key")
} }
return tls.LoadX509KeyPair(certFile, keyFile) return tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
} }
type DowngradingListener struct { type DowngradingListener struct {