summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDave Henderson <dhenderson@gmail.com>2024-12-16 16:11:27 -0500
committerGitHub <noreply@github.com>2024-12-16 21:11:27 +0000
commitcd74bb8eae53597d8272bcdffdaa34e1b839a700 (patch)
tree8bc9529c053e1561ec8b611f6b9827a3153693ff
parenta69bb645210bd335e2ce3cece1b54af3a286db7a (diff)
fix(fs): Cache data in stdinfs (#2288)
Signed-off-by: Dave Henderson <dhenderson@gmail.com>
-rw-r--r--internal/datafs/stdinfs.go38
-rw-r--r--internal/datafs/stdinfs_test.go51
2 files changed, 85 insertions, 4 deletions
diff --git a/internal/datafs/stdinfs.go b/internal/datafs/stdinfs.go
index 46cb030a..f6a02c51 100644
--- a/internal/datafs/stdinfs.go
+++ b/internal/datafs/stdinfs.go
@@ -18,7 +18,8 @@ func NewStdinFS(_ *url.URL) (fs.FS, error) {
}
type stdinFS struct {
- ctx context.Context
+ ctx context.Context
+ data []byte
}
//nolint:gochecknoglobals
@@ -46,9 +47,15 @@ func (f *stdinFS) Open(name string) (fs.File, error) {
}
}
- stdin := StdinFromContext(f.ctx)
+ if err := f.readData(); err != nil {
+ return nil, &fs.PathError{
+ Op: "open",
+ Path: name,
+ Err: err,
+ }
+ }
- return &stdinFile{name: name, body: stdin}, nil
+ return &stdinFile{name: name, body: bytes.NewReader(f.data)}, nil
}
func (f *stdinFS) ReadFile(name string) ([]byte, error) {
@@ -60,9 +67,32 @@ func (f *stdinFS) ReadFile(name string) ([]byte, error) {
}
}
+ if err := f.readData(); err != nil {
+ return nil, &fs.PathError{
+ Op: "readFile",
+ Path: name,
+ Err: err,
+ }
+ }
+
+ return f.data, nil
+}
+
+func (f *stdinFS) readData() error {
+ if f.data != nil {
+ return nil
+ }
+
stdin := StdinFromContext(f.ctx)
- return io.ReadAll(stdin)
+ b, err := io.ReadAll(stdin)
+ if err != nil {
+ return err
+ }
+
+ f.data = b
+
+ return nil
}
type stdinFile struct {
diff --git a/internal/datafs/stdinfs_test.go b/internal/datafs/stdinfs_test.go
index f8c30a06..f5010a47 100644
--- a/internal/datafs/stdinfs_test.go
+++ b/internal/datafs/stdinfs_test.go
@@ -99,6 +99,57 @@ func TestStdinFS(t *testing.T) {
_, err = f.Read(p)
require.Error(t, err)
require.ErrorIs(t, err, io.EOF)
+
+ t.Run("open/read multiple times", func(t *testing.T) {
+ ctx := ContextWithStdin(context.Background(), bytes.NewReader(content))
+ fsys = fsimpl.WithContextFS(ctx, fsys)
+
+ for i := 0; i < 3; i++ {
+ f, err := fsys.Open("foo")
+ require.NoError(t, err)
+
+ b, err := io.ReadAll(f)
+ require.NoError(t, err)
+ require.Equal(t, content, b, "read %d failed", i)
+ }
+ })
+
+ t.Run("readFile multiple times", func(t *testing.T) {
+ ctx := ContextWithStdin(context.Background(), bytes.NewReader(content))
+ fsys = fsimpl.WithContextFS(ctx, fsys)
+
+ for i := 0; i < 3; i++ {
+ b, err := fs.ReadFile(fsys, "foo")
+ require.NoError(t, err)
+ require.Equal(t, content, b, "read %d failed", i)
+ }
+ })
+
+ t.Run("open errors", func(t *testing.T) {
+ ctx := ContextWithStdin(context.Background(), &errorReader{err: fs.ErrPermission})
+
+ fsys, err := NewStdinFS(u)
+ require.NoError(t, err)
+ assert.IsType(t, &stdinFS{}, fsys)
+
+ fsys = fsimpl.WithContextFS(ctx, fsys)
+
+ _, err = fsys.Open("foo")
+ require.ErrorIs(t, err, fs.ErrPermission)
+ })
+
+ t.Run("readFile errors", func(t *testing.T) {
+ ctx := ContextWithStdin(context.Background(), &errorReader{err: fs.ErrPermission})
+
+ fsys, err := NewStdinFS(u)
+ require.NoError(t, err)
+ assert.IsType(t, &stdinFS{}, fsys)
+
+ fsys = fsimpl.WithContextFS(ctx, fsys)
+
+ _, err = fs.ReadFile(fsys, "foo")
+ require.ErrorIs(t, err, fs.ErrPermission)
+ })
}
type errorReader struct {