diff options
| author | Dave Henderson <dhenderson@gmail.com> | 2021-01-23 15:16:57 -0500 |
|---|---|---|
| committer | Dave Henderson <dhenderson@gmail.com> | 2021-01-23 16:59:01 -0500 |
| commit | dab034d8edaabd1c183f99c33815dae151af6ab4 (patch) | |
| tree | 81e07d096c7215f0d5a04ba7c30f4b553311b76f /internal | |
| parent | f20592a0422a12774474e61f28dbcade280a5909 (diff) | |
Fix race condition in signal handling
Signed-off-by: Dave Henderson <dhenderson@gmail.com>
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/cmd/main.go | 25 | ||||
| -rw-r--r-- | internal/cmd/main_test.go | 16 |
2 files changed, 35 insertions, 6 deletions
diff --git a/internal/cmd/main.go b/internal/cmd/main.go index 5cbbc47a..da02670d 100644 --- a/internal/cmd/main.go +++ b/internal/cmd/main.go @@ -24,6 +24,10 @@ func postRunExec(ctx context.Context, cfg *config.Config, stdout, stderr io.Writ log := zerolog.Ctx(ctx) log.Debug().Strs("args", args).Msg("running post-exec command") + //nolint:govet + ctx, cancel := context.WithCancel(ctx) + defer cancel() + name := args[0] args = args[1:] // nolint: gosec @@ -35,16 +39,25 @@ func postRunExec(ctx context.Context, cfg *config.Config, stdout, stderr io.Writ // make sure all signals are propagated sigs := make(chan os.Signal, 1) signal.Notify(sigs) + + err := c.Start() + if err != nil { + return err + } + go func() { - // Pass signals to the sub-process - sig := <-sigs - if c.Process != nil { - // nolint: gosec - _ = c.Process.Signal(sig) + select { + case sig := <-sigs: + // Pass signals to the sub-process + if c.Process != nil { + // nolint: gosec + _ = c.Process.Signal(sig) + } + case <-ctx.Done(): } }() - return c.Run() + return c.Wait() } return nil } diff --git a/internal/cmd/main_test.go b/internal/cmd/main_test.go index 678d5979..7f8bd6f7 100644 --- a/internal/cmd/main_test.go +++ b/internal/cmd/main_test.go @@ -3,8 +3,10 @@ package cmd import ( "bytes" "context" + "strings" "testing" + "github.com/hairyhenderson/gomplate/v3/internal/config" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" ) @@ -56,3 +58,17 @@ func TestRunMain(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "hello", stdout.String()) } + +func TestPostRunExec(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cfg := &config.Config{ + PostExecInput: strings.NewReader("hello world"), + PostExec: []string{"cat"}, + } + out := &bytes.Buffer{} + err := postRunExec(ctx, cfg, out, out) + assert.NoError(t, err) + assert.Equal(t, "hello world", out.String()) +} |
