diff options
Diffstat (limited to 'template.go')
| -rw-r--r-- | template.go | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/template.go b/template.go index b500ce73..448eff84 100644 --- a/template.go +++ b/template.go @@ -9,7 +9,7 @@ import ( "text/template" "github.com/hairyhenderson/gomplate/v3/internal/config" - "github.com/hairyhenderson/gomplate/v3/internal/writers" + "github.com/hairyhenderson/gomplate/v3/internal/iohelpers" "github.com/hairyhenderson/gomplate/v3/tmpl" "github.com/spf13/afero" @@ -101,7 +101,7 @@ func gatherTemplates(cfg *config.Config, outFileNamer func(string) (string, erro // --exec-pipe redirects standard out to the out pipe if cfg.OutWriter != nil { - Stdout = &writers.NopCloser{Writer: cfg.OutWriter} + Stdout = &iohelpers.NopCloser{Writer: cfg.OutWriter} } switch { @@ -227,7 +227,7 @@ func fileToTemplates(inFile, outFile string, mode os.FileMode, modeOverride bool func openOutFile(cfg *config.Config, filename string, mode os.FileMode, modeOverride bool) (out io.WriteCloser, err error) { if cfg.SuppressEmpty { - out = writers.NewEmptySkipper(func() (io.WriteCloser, error) { + out = iohelpers.NewEmptySkipper(func() (io.WriteCloser, error) { if filename == "-" { return Stdout, nil } @@ -260,14 +260,19 @@ func createOutFile(filename string, mode os.FileMode, modeOverride bool) (out io } // if the output file already exists, we'll use a SameSkipper - f, err := fs.OpenFile(filename, os.O_RDONLY, mode.Perm()) + fi, err := fs.Stat(filename) if err != nil { - // likely means the file just doesn't exist - open's error will be more useful - return open() + // likely means the file just doesn't exist - further errors will be more useful + return iohelpers.LazyWriteCloser(open), nil } - out = writers.SameSkipper(f, func() (io.WriteCloser, error) { - return open() - }) + if fi.IsDir() { + // error because this is a directory + return nil, isDirError(fi.Name()) + } + + out = iohelpers.SameSkipper(iohelpers.LazyReadCloser(func() (io.ReadCloser, error) { + return fs.OpenFile(filename, os.O_RDONLY, mode.Perm()) + }), open) return out, err } @@ -280,14 +285,14 @@ func readInput(filename string) (string, error) { } else { inFile, err = fs.OpenFile(filename, os.O_RDONLY, 0) if err != nil { - return "", fmt.Errorf("failed to open %s\n%v", filename, err) + return "", fmt.Errorf("failed to open %s: %w", filename, err) } // nolint: errcheck defer inFile.Close() } bytes, err := ioutil.ReadAll(inFile) if err != nil { - err = fmt.Errorf("read failed for %s\n%v", filename, err) + err = fmt.Errorf("read failed for %s: %w", filename, err) return "", err } return string(bytes), nil |
