From 661b8073b9e77a84ecf0111ad9464b4bfe0bbfe2 Mon Sep 17 00:00:00 2001 From: Dave Henderson Date: Thu, 14 Feb 2019 22:16:53 -0500 Subject: New file.Write function Signed-off-by: Dave Henderson --- file/file.go | 50 ++++++++++++++++++++++++++++++++++++++++++ file/file_test.go | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) (limited to 'file') diff --git a/file/file.go b/file/file.go index f8fa2839..16067161 100644 --- a/file/file.go +++ b/file/file.go @@ -3,6 +3,8 @@ package file import ( "io/ioutil" "os" + "path/filepath" + "strings" "github.com/pkg/errors" @@ -43,3 +45,51 @@ func ReadDir(path string) ([]string, error) { } return nil, errors.New("file is not a directory") } + +// Write a +func Write(filename string, content []byte) error { + err := assertPathInWD(filename) + if err != nil { + return errors.Wrapf(err, "failed to open %s", filename) + } + + fi, err := os.Stat(filename) + if err != nil && !os.IsNotExist(err) { + return errors.Wrapf(err, "failed to stat %s", filename) + } + mode := os.FileMode(0644) + if fi != nil { + mode = fi.Mode() + } + inFile, err := fs.OpenFile(filename, os.O_RDWR|os.O_CREATE, mode) + if err != nil { + return errors.Wrapf(err, "failed to open %s", filename) + } + n, err := inFile.Write(content) + if err != nil { + return errors.Wrapf(err, "failed to write %s", filename) + } + if n != len(content) { + return errors.Wrapf(err, "short write on %s (%d bytes)", filename, n) + } + return nil +} + +func assertPathInWD(filename string) error { + wd, err := os.Getwd() + if err != nil { + return err + } + f, err := filepath.Abs(filename) + if err != nil { + return err + } + r, err := filepath.Rel(wd, f) + if err != nil { + return err + } + if strings.HasPrefix(r, "..") { + return errors.Errorf("path %s not contained by working directory %s (rel: %s)", filename, wd, r) + } + return nil +} diff --git a/file/file_test.go b/file/file_test.go index 144c67a5..47c0d542 100644 --- a/file/file_test.go +++ b/file/file_test.go @@ -1,8 +1,12 @@ package file import ( + "io/ioutil" + "os" + "path/filepath" "testing" + tfs "github.com/gotestyourself/gotestyourself/fs" "github.com/spf13/afero" "github.com/stretchr/testify/assert" ) @@ -41,3 +45,64 @@ func TestReadDir(t *testing.T) { _, err = ReadDir("/tmp/foo") assert.Error(t, err) } + +func TestWrite(t *testing.T) { + oldwd, _ := os.Getwd() + defer os.Chdir(oldwd) + + rootDir := tfs.NewDir(t, "gomplate-test") + defer rootDir.Remove() + + newwd := rootDir.Join("the", "path", "we", "want") + badwd := rootDir.Join("some", "other", "dir") + fs.MkdirAll(newwd, 0755) + fs.MkdirAll(badwd, 0755) + newwd, _ = filepath.EvalSymlinks(newwd) + badwd, _ = filepath.EvalSymlinks(badwd) + + err := os.Chdir(newwd) + assert.NoError(t, err) + + err = Write("/foo", []byte("Hello world")) + assert.Error(t, err) + + rel, err := filepath.Rel(newwd, badwd) + assert.NoError(t, err) + err = Write(rel, []byte("Hello world")) + assert.Error(t, err) + + foopath := filepath.Join(newwd, "foo") + err = Write(foopath, []byte("Hello world")) + assert.NoError(t, err) + + out, err := ioutil.ReadFile(foopath) + assert.NoError(t, err) + assert.Equal(t, "Hello world", string(out)) +} + +func TestAssertPathInWD(t *testing.T) { + oldwd, _ := os.Getwd() + defer os.Chdir(oldwd) + + err := assertPathInWD("/tmp") + assert.Error(t, err) + + err = assertPathInWD(filepath.Join(oldwd, "subpath")) + assert.NoError(t, err) + + err = assertPathInWD("subpath") + assert.NoError(t, err) + + err = assertPathInWD("./subpath") + assert.NoError(t, err) + + err = assertPathInWD(filepath.Join("..", "bogus")) + assert.Error(t, err) + + err = assertPathInWD(filepath.Join("..", "..", "bogus")) + assert.Error(t, err) + + base := filepath.Base(oldwd) + err = assertPathInWD(filepath.Join("..", base)) + assert.NoError(t, err) +} -- cgit v1.2.3