From 01b39ee0955b59abdf1895b18435f1731226fcac Mon Sep 17 00:00:00 2001 From: Dave Henderson Date: Sat, 18 Nov 2017 09:33:49 -0500 Subject: Adding support for stdin: scheme for datasources Signed-off-by: Dave Henderson --- data/datasource.go | 20 ++++++++++++++++++++ data/datasource_test.go | 21 +++++++++++++++++++++ 2 files changed, 41 insertions(+) (limited to 'data') diff --git a/data/datasource.go b/data/datasource.go index 9824e41f..ef693761 100644 --- a/data/datasource.go +++ b/data/datasource.go @@ -3,6 +3,7 @@ package data import ( "errors" "fmt" + "io" "io/ioutil" "log" "mime" @@ -22,6 +23,9 @@ import ( // logFatal is defined so log.Fatal calls can be overridden for testing var logFatalf = log.Fatalf +// stdin - for overriding in tests +var stdin io.Reader + func regExtension(ext, typ string) { err := mime.AddExtensionType(ext, typ) if err != nil { @@ -43,6 +47,7 @@ func init() { addSourceReader("http", readHTTP) addSourceReader("https", readHTTP) addSourceReader("file", readFile) + addSourceReader("stdin", readStdin) addSourceReader("vault", readVault) addSourceReader("consul", readConsul) addSourceReader("consul+http", readConsul) @@ -157,6 +162,9 @@ func ParseSource(value string) (*Source, error) { srcURL = absURL(f) } else if len(parts) == 2 { alias = parts[0] + if parts[1] == "-" { + parts[1] = "stdin://" + } var err error srcURL, err = url.Parse(parts[1]) if err != nil { @@ -296,6 +304,18 @@ func readFile(source *Source, args ...string) ([]byte, error) { return b, nil } +func readStdin(source *Source, args ...string) ([]byte, error) { + if stdin == nil { + stdin = os.Stdin + } + b, err := ioutil.ReadAll(stdin) + if err != nil { + log.Printf("Can't read %v: %#v", stdin, err) + return nil, err + } + return b, nil +} + func readHTTP(source *Source, args ...string) ([]byte, error) { if source.HC == nil { source.HC = &http.Client{Timeout: time.Second * 5} diff --git a/data/datasource_test.go b/data/datasource_test.go index a3b6f31e..ac7b884a 100644 --- a/data/datasource_test.go +++ b/data/datasource_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/blang/vfs" @@ -307,3 +308,23 @@ func TestInclude(t *testing.T) { actual := data.Include("foo") assert.Equal(t, contents, actual) } + +type errorReader struct{} + +func (e errorReader) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("error") +} + +func TestReadStdin(t *testing.T) { + defer func() { + stdin = nil + }() + stdin = strings.NewReader("foo") + out, err := readStdin(nil) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), out) + + stdin = errorReader{} + _, err = readStdin(nil) + assert.Error(t, err) +} -- cgit v1.2.3