diff --git a/lib/fs/basicfs.go b/lib/fs/basicfs.go index 32dd9de06..91020ce5b 100644 --- a/lib/fs/basicfs.go +++ b/lib/fs/basicfs.go @@ -297,7 +297,7 @@ func (f *BasicFilesystem) SameFile(fi1, fi2 FileInfo) bool { return false } - return os.SameFile(f1.FileInfo, f2.FileInfo) + return os.SameFile(f1.osFileInfo(), f2.osFileInfo()) } // basicFile implements the fs.File interface on top of an os.File diff --git a/lib/fs/basicfs_fileinfo_unix.go b/lib/fs/basicfs_fileinfo_unix.go index 6af8d6797..c21ba3807 100644 --- a/lib/fs/basicfs_fileinfo_unix.go +++ b/lib/fs/basicfs_fileinfo_unix.go @@ -8,7 +8,10 @@ package fs -import "syscall" +import ( + "os" + "syscall" +) func (e basicFileInfo) Mode() FileMode { return FileMode(e.FileInfo.Mode()) @@ -27,3 +30,9 @@ func (e basicFileInfo) Group() int { } return -1 } + +// fileStat converts e to os.FileInfo that is suitable +// to be passed to os.SameFile. Non-trivial on Windows. +func (e *basicFileInfo) osFileInfo() os.FileInfo { + return e.FileInfo +} diff --git a/lib/fs/basicfs_fileinfo_windows.go b/lib/fs/basicfs_fileinfo_windows.go index 0f0277728..e510b60cd 100644 --- a/lib/fs/basicfs_fileinfo_windows.go +++ b/lib/fs/basicfs_fileinfo_windows.go @@ -56,3 +56,13 @@ func (e basicFileInfo) Owner() int { func (e basicFileInfo) Group() int { return -1 } + +// osFileInfo converts e to os.FileInfo that is suitable +// to be passed to os.SameFile. +func (e *basicFileInfo) osFileInfo() os.FileInfo { + fi := e.FileInfo + if fi, ok := fi.(*dirJunctFileInfo); ok { + return fi.FileInfo + } + return fi +} diff --git a/lib/fs/basicfs_test.go b/lib/fs/basicfs_test.go index 7b23d7883..c900a131d 100644 --- a/lib/fs/basicfs_test.go +++ b/lib/fs/basicfs_test.go @@ -577,3 +577,15 @@ func TestBasicWalkSkipSymlink(t *testing.T) { defer os.RemoveAll(dir) testWalkSkipSymlink(t, FilesystemTypeBasic, dir) } + +func TestWalkTraverseDirJunct(t *testing.T) { + _, dir := setup(t) + defer os.RemoveAll(dir) + testWalkTraverseDirJunct(t, FilesystemTypeBasic, dir) +} + +func TestWalkInfiniteRecursion(t *testing.T) { + _, dir := setup(t) + defer os.RemoveAll(dir) + testWalkInfiniteRecursion(t, FilesystemTypeBasic, dir) +} diff --git a/lib/fs/lstat_regular.go b/lib/fs/lstat_regular.go index 17f41ec12..cd31091a8 100644 --- a/lib/fs/lstat_regular.go +++ b/lib/fs/lstat_regular.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at https://mozilla.org/MPL/2.0/. -// +build !linux,!android +// +build !linux,!android,!windows package fs diff --git a/lib/fs/lstat_windows.go b/lib/fs/lstat_windows.go new file mode 100644 index 000000000..e05be7b9e --- /dev/null +++ b/lib/fs/lstat_windows.go @@ -0,0 +1,80 @@ +// Copyright (C) 2015 The Syncthing Authors. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +// +build windows + +package fs + +import ( + "fmt" + "os" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +func isDirectoryJunction(path string) (bool, error) { + namep, err := syscall.UTF16PtrFromString(path) + if err != nil { + return false, fmt.Errorf("syscall.UTF16PtrFromString failed with: %s", err) + } + attrs := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS | syscall.FILE_FLAG_OPEN_REPARSE_POINT) + h, err := syscall.CreateFile(namep, 0, 0, nil, syscall.OPEN_EXISTING, attrs, 0) + if err != nil { + return false, fmt.Errorf("syscall.CreateFile failed with: %s", err) + } + defer syscall.CloseHandle(h) + + //https://docs.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_attribute_tag_info + const fileAttributeTagInfo = 9 + type FILE_ATTRIBUTE_TAG_INFO struct { + FileAttributes uint32 + ReparseTag uint32 + } + + var ti FILE_ATTRIBUTE_TAG_INFO + err = windows.GetFileInformationByHandleEx(windows.Handle(h), fileAttributeTagInfo, (*byte)(unsafe.Pointer(&ti)), uint32(unsafe.Sizeof(ti))) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == windows.ERROR_INVALID_PARAMETER { + // It appears calling GetFileInformationByHandleEx with + // FILE_ATTRIBUTE_TAG_INFO fails on FAT file system with + // ERROR_INVALID_PARAMETER. Clear ti.ReparseTag in that + // instance to indicate no symlinks are possible. + ti.ReparseTag = 0 + } else { + return false, fmt.Errorf("windows.GetFileInformationByHandleEx failed with: %s", err) + } + } + return ti.ReparseTag == windows.IO_REPARSE_TAG_MOUNT_POINT, nil +} + +type dirJunctFileInfo struct { + os.FileInfo +} + +func (fi *dirJunctFileInfo) Mode() os.FileMode { + return fi.FileInfo.Mode() ^ os.ModeSymlink | os.ModeDir +} + +func (fi *dirJunctFileInfo) IsDir() bool { + return true +} + +func underlyingLstat(name string) (os.FileInfo, error) { + var fi, err = os.Lstat(name) + + // NTFS directory junctions are treated as ordinary directories, + // see https://forum.syncthing.net/t/option-to-follow-directory-junctions-symbolic-links/14750 + if err == nil && fi.Mode()&os.ModeSymlink != 0 { + var isJunct bool + isJunct, err = isDirectoryJunction(name) + if err == nil && isJunct { + return &dirJunctFileInfo{fi}, nil + } + } + return fi, err +} diff --git a/lib/fs/walkfs.go b/lib/fs/walkfs.go index e56555932..6582e1ff0 100644 --- a/lib/fs/walkfs.go +++ b/lib/fs/walkfs.go @@ -10,7 +10,37 @@ package fs -import "path/filepath" +import ( + "path/filepath" +) + +type ancestorDirList struct { + list []FileInfo + fs Filesystem +} + +func (ancestors *ancestorDirList) Push(info FileInfo) { + l.Debugf("ancestorDirList: Push '%s'", info.Name()) + ancestors.list = append(ancestors.list, info) +} + +func (ancestors *ancestorDirList) Pop() FileInfo { + aLen := len(ancestors.list) + info := ancestors.list[aLen-1] + l.Debugf("ancestorDirList: Pop '%s'", info.Name()) + ancestors.list = ancestors.list[:aLen-1] + return info +} + +func (ancestors *ancestorDirList) Contains(info FileInfo) bool { + l.Debugf("ancestorDirList: Contains '%s'", info.Name()) + for _, ancestor := range ancestors.list { + if ancestors.fs.SameFile(info, ancestor) { + return true + } + } + return false +} // WalkFunc is the type of the function called for each file or directory // visited by Walk. The path argument contains the argument to Walk as a @@ -37,7 +67,8 @@ func NewWalkFilesystem(next Filesystem) Filesystem { } // walk recursively descends path, calling walkFn. -func (f *walkFilesystem) walk(path string, info FileInfo, walkFn WalkFunc) error { +func (f *walkFilesystem) walk(path string, info FileInfo, walkFn WalkFunc, ancestors *ancestorDirList) error { + l.Debugf("walk: path=%s", path) path, err := Canonicalize(path) if err != nil { return err @@ -55,6 +86,14 @@ func (f *walkFilesystem) walk(path string, info FileInfo, walkFn WalkFunc) error return nil } + if !ancestors.Contains(info) { + ancestors.Push(info) + defer ancestors.Pop() + } else { + l.Warnf("Infinite filesystem recursion detected on path '%s', not walking further down", path) + return nil + } + names, err := f.DirNames(path) if err != nil { return walkFn(path, info, err) @@ -68,7 +107,7 @@ func (f *walkFilesystem) walk(path string, info FileInfo, walkFn WalkFunc) error return err } } else { - err = f.walk(filename, fileInfo, walkFn) + err = f.walk(filename, fileInfo, walkFn, ancestors) if err != nil { if !fileInfo.IsDir() || err != SkipDir { return err @@ -90,5 +129,6 @@ func (f *walkFilesystem) Walk(root string, walkFn WalkFunc) error { if err != nil { return walkFn(root, nil, err) } - return f.walk(root, info, walkFn) + ancestors := &ancestorDirList{fs: f.Filesystem} + return f.walk(root, info, walkFn, ancestors) } diff --git a/lib/fs/walkfs_test.go b/lib/fs/walkfs_test.go index b09cbf76a..7ec0a01c8 100644 --- a/lib/fs/walkfs_test.go +++ b/lib/fs/walkfs_test.go @@ -7,13 +7,16 @@ package fs import ( + "fmt" + osexec "os/exec" + "path/filepath" "runtime" "testing" ) func testWalkSkipSymlink(t *testing.T, fsType FilesystemType, uri string) { if runtime.GOOS == "windows" { - t.Skip("Symlinks on windows") + t.Skip("Symlinks skipping is not tested on windows") } fs := NewFilesystem(fsType, uri) @@ -39,3 +42,83 @@ func testWalkSkipSymlink(t *testing.T, fsType FilesystemType, uri string) { t.Fatal(err) } } + +func createDirJunct(target string, name string) error { + output, err := osexec.Command("cmd", "/c", "mklink", "/J", name, target).CombinedOutput() + if err != nil { + return fmt.Errorf("Failed to run mklink %v %v: %v %q", name, target, err, output) + } + return nil +} + +func testWalkTraverseDirJunct(t *testing.T, fsType FilesystemType, uri string) { + if runtime.GOOS != "windows" { + t.Skip("Directory junctions are available and tested on windows only") + } + + fs := NewFilesystem(fsType, uri) + + if err := fs.MkdirAll("target/foo", 0); err != nil { + t.Fatal(err) + } + if err := fs.Mkdir("towalk", 0); err != nil { + t.Fatal(err) + } + if err := createDirJunct(filepath.Join(uri, "target"), filepath.Join(uri, "towalk/dirjunct")); err != nil { + t.Fatal(err) + } + traversed := false + if err := fs.Walk("towalk", func(path string, info FileInfo, err error) error { + if err != nil { + t.Fatal(err) + } + if info.Name() == "foo" { + traversed = true + } + return nil + }); err != nil { + t.Fatal(err) + } + if !traversed { + t.Fatal("Directory junction was not traversed") + } +} + +func testWalkInfiniteRecursion(t *testing.T, fsType FilesystemType, uri string) { + if runtime.GOOS != "windows" { + t.Skip("Infinite recursion detection is tested on windows only") + } + + fs := NewFilesystem(fsType, uri) + + if err := fs.MkdirAll("target/foo", 0); err != nil { + t.Fatal(err) + } + if err := fs.Mkdir("towalk", 0); err != nil { + t.Fatal(err) + } + if err := createDirJunct(filepath.Join(uri, "target"), filepath.Join(uri, "towalk/dirjunct")); err != nil { + t.Fatal(err) + } + if err := createDirJunct(filepath.Join(uri, "towalk"), filepath.Join(uri, "target/foo/recurse")); err != nil { + t.Fatal(err) + } + dirjunctCnt := 0 + fooCnt := 0 + if err := fs.Walk("towalk", func(path string, info FileInfo, err error) error { + if err != nil { + t.Fatal(err) + } + if info.Name() == "dirjunct" { + dirjunctCnt++ + } else if info.Name() == "foo" { + fooCnt++ + } + return nil + }); err != nil { + t.Fatal(err) + } + if dirjunctCnt != 2 || fooCnt != 1 { + t.Fatal("Infinite recursion not detected correctly") + } +}