summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorDave Henderson <dhenderson@gmail.com>2021-01-23 15:16:57 -0500
committerDave Henderson <dhenderson@gmail.com>2021-01-23 16:59:01 -0500
commitdab034d8edaabd1c183f99c33815dae151af6ab4 (patch)
tree81e07d096c7215f0d5a04ba7c30f4b553311b76f /internal
parentf20592a0422a12774474e61f28dbcade280a5909 (diff)
Fix race condition in signal handling
Signed-off-by: Dave Henderson <dhenderson@gmail.com>
Diffstat (limited to 'internal')
-rw-r--r--internal/cmd/main.go25
-rw-r--r--internal/cmd/main_test.go16
2 files changed, 35 insertions, 6 deletions
diff --git a/internal/cmd/main.go b/internal/cmd/main.go
index 5cbbc47a..da02670d 100644
--- a/internal/cmd/main.go
+++ b/internal/cmd/main.go
@@ -24,6 +24,10 @@ func postRunExec(ctx context.Context, cfg *config.Config, stdout, stderr io.Writ
log := zerolog.Ctx(ctx)
log.Debug().Strs("args", args).Msg("running post-exec command")
+ //nolint:govet
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
name := args[0]
args = args[1:]
// nolint: gosec
@@ -35,16 +39,25 @@ func postRunExec(ctx context.Context, cfg *config.Config, stdout, stderr io.Writ
// make sure all signals are propagated
sigs := make(chan os.Signal, 1)
signal.Notify(sigs)
+
+ err := c.Start()
+ if err != nil {
+ return err
+ }
+
go func() {
- // Pass signals to the sub-process
- sig := <-sigs
- if c.Process != nil {
- // nolint: gosec
- _ = c.Process.Signal(sig)
+ select {
+ case sig := <-sigs:
+ // Pass signals to the sub-process
+ if c.Process != nil {
+ // nolint: gosec
+ _ = c.Process.Signal(sig)
+ }
+ case <-ctx.Done():
}
}()
- return c.Run()
+ return c.Wait()
}
return nil
}
diff --git a/internal/cmd/main_test.go b/internal/cmd/main_test.go
index 678d5979..7f8bd6f7 100644
--- a/internal/cmd/main_test.go
+++ b/internal/cmd/main_test.go
@@ -3,8 +3,10 @@ package cmd
import (
"bytes"
"context"
+ "strings"
"testing"
+ "github.com/hairyhenderson/gomplate/v3/internal/config"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
@@ -56,3 +58,17 @@ func TestRunMain(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "hello", stdout.String())
}
+
+func TestPostRunExec(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ cfg := &config.Config{
+ PostExecInput: strings.NewReader("hello world"),
+ PostExec: []string{"cat"},
+ }
+ out := &bytes.Buffer{}
+ err := postRunExec(ctx, cfg, out, out)
+ assert.NoError(t, err)
+ assert.Equal(t, "hello world", out.String())
+}