summaryrefslogtreecommitdiff
path: root/file
diff options
context:
space:
mode:
authorDave Henderson <dhenderson@gmail.com>2019-02-14 22:16:53 -0500
committerDave Henderson <dhenderson@gmail.com>2019-02-15 21:06:55 -0500
commit661b8073b9e77a84ecf0111ad9464b4bfe0bbfe2 (patch)
treea726654f3f83797f173c684965510d508805e5b3 /file
parent6ab3a6f6f3795aa2e7fda89bf4eae1720d5418e8 (diff)
New file.Write function
Signed-off-by: Dave Henderson <dhenderson@gmail.com>
Diffstat (limited to 'file')
-rw-r--r--file/file.go50
-rw-r--r--file/file_test.go65
2 files changed, 115 insertions, 0 deletions
diff --git a/file/file.go b/file/file.go
index f8fa2839..16067161 100644
--- a/file/file.go
+++ b/file/file.go
@@ -3,6 +3,8 @@ package file
import (
"io/ioutil"
"os"
+ "path/filepath"
+ "strings"
"github.com/pkg/errors"
@@ -43,3 +45,51 @@ func ReadDir(path string) ([]string, error) {
}
return nil, errors.New("file is not a directory")
}
+
+// Write a
+func Write(filename string, content []byte) error {
+ err := assertPathInWD(filename)
+ if err != nil {
+ return errors.Wrapf(err, "failed to open %s", filename)
+ }
+
+ fi, err := os.Stat(filename)
+ if err != nil && !os.IsNotExist(err) {
+ return errors.Wrapf(err, "failed to stat %s", filename)
+ }
+ mode := os.FileMode(0644)
+ if fi != nil {
+ mode = fi.Mode()
+ }
+ inFile, err := fs.OpenFile(filename, os.O_RDWR|os.O_CREATE, mode)
+ if err != nil {
+ return errors.Wrapf(err, "failed to open %s", filename)
+ }
+ n, err := inFile.Write(content)
+ if err != nil {
+ return errors.Wrapf(err, "failed to write %s", filename)
+ }
+ if n != len(content) {
+ return errors.Wrapf(err, "short write on %s (%d bytes)", filename, n)
+ }
+ return nil
+}
+
+func assertPathInWD(filename string) error {
+ wd, err := os.Getwd()
+ if err != nil {
+ return err
+ }
+ f, err := filepath.Abs(filename)
+ if err != nil {
+ return err
+ }
+ r, err := filepath.Rel(wd, f)
+ if err != nil {
+ return err
+ }
+ if strings.HasPrefix(r, "..") {
+ return errors.Errorf("path %s not contained by working directory %s (rel: %s)", filename, wd, r)
+ }
+ return nil
+}
diff --git a/file/file_test.go b/file/file_test.go
index 144c67a5..47c0d542 100644
--- a/file/file_test.go
+++ b/file/file_test.go
@@ -1,8 +1,12 @@
package file
import (
+ "io/ioutil"
+ "os"
+ "path/filepath"
"testing"
+ tfs "github.com/gotestyourself/gotestyourself/fs"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
)
@@ -41,3 +45,64 @@ func TestReadDir(t *testing.T) {
_, err = ReadDir("/tmp/foo")
assert.Error(t, err)
}
+
+func TestWrite(t *testing.T) {
+ oldwd, _ := os.Getwd()
+ defer os.Chdir(oldwd)
+
+ rootDir := tfs.NewDir(t, "gomplate-test")
+ defer rootDir.Remove()
+
+ newwd := rootDir.Join("the", "path", "we", "want")
+ badwd := rootDir.Join("some", "other", "dir")
+ fs.MkdirAll(newwd, 0755)
+ fs.MkdirAll(badwd, 0755)
+ newwd, _ = filepath.EvalSymlinks(newwd)
+ badwd, _ = filepath.EvalSymlinks(badwd)
+
+ err := os.Chdir(newwd)
+ assert.NoError(t, err)
+
+ err = Write("/foo", []byte("Hello world"))
+ assert.Error(t, err)
+
+ rel, err := filepath.Rel(newwd, badwd)
+ assert.NoError(t, err)
+ err = Write(rel, []byte("Hello world"))
+ assert.Error(t, err)
+
+ foopath := filepath.Join(newwd, "foo")
+ err = Write(foopath, []byte("Hello world"))
+ assert.NoError(t, err)
+
+ out, err := ioutil.ReadFile(foopath)
+ assert.NoError(t, err)
+ assert.Equal(t, "Hello world", string(out))
+}
+
+func TestAssertPathInWD(t *testing.T) {
+ oldwd, _ := os.Getwd()
+ defer os.Chdir(oldwd)
+
+ err := assertPathInWD("/tmp")
+ assert.Error(t, err)
+
+ err = assertPathInWD(filepath.Join(oldwd, "subpath"))
+ assert.NoError(t, err)
+
+ err = assertPathInWD("subpath")
+ assert.NoError(t, err)
+
+ err = assertPathInWD("./subpath")
+ assert.NoError(t, err)
+
+ err = assertPathInWD(filepath.Join("..", "bogus"))
+ assert.Error(t, err)
+
+ err = assertPathInWD(filepath.Join("..", "..", "bogus"))
+ assert.Error(t, err)
+
+ base := filepath.Base(oldwd)
+ err = assertPathInWD(filepath.Join("..", base))
+ assert.NoError(t, err)
+}