summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEng Zer Jun <engzerjun@gmail.com>2023-12-19 10:08:45 +0800
committerGitHub <noreply@github.com>2023-12-19 02:08:45 +0000
commit483af656e6ec51ff1263743eb4957d7d70e6c51f (patch)
treeb85def06f7fbb7e08720a021f0d2a97a7c065907
parente6835bcb487cf5795a53561c3a50bf2090847714 (diff)
test: use `t.Setenv` to set env vars in tests (#1940)
* test: use `t.Setenv` to set env vars in tests This commit replaces `os.Setenv` with `t.Setenv` in tests. The environment variable is automatically restored to its original value when the test and all its subtests complete. Reference: https://pkg.go.dev/testing#T.Setenv Signed-off-by: Eng Zer Jun <engzerjun@gmail.com> * minor adjustments Signed-off-by: Dave Henderson <dhenderson@gmail.com> --------- Signed-off-by: Eng Zer Jun <engzerjun@gmail.com> Signed-off-by: Dave Henderson <dhenderson@gmail.com> Co-authored-by: Dave Henderson <dhenderson@gmail.com>
-rw-r--r--aws/ec2info_test.go96
-rw-r--r--context_test.go8
-rw-r--r--data/datasource_blob_test.go69
-rw-r--r--data/datasource_env_test.go7
-rw-r--r--data/datasource_git_test.go45
-rw-r--r--internal/cmd/config_test.go56
-rw-r--r--internal/cmd/logger_test.go3
-rw-r--r--libkv/consul_test.go226
-rw-r--r--render_test.go3
-rw-r--r--vault/auth_test.go7
-rw-r--r--vault/vault_test.go3
11 files changed, 264 insertions, 259 deletions
diff --git a/aws/ec2info_test.go b/aws/ec2info_test.go
index 2a9efefe..497f50d1 100644
--- a/aws/ec2info_test.go
+++ b/aws/ec2info_test.go
@@ -133,64 +133,64 @@ func TestNewEc2Info(t *testing.T) {
}
func TestGetRegion(t *testing.T) {
- oldReg, ok := os.LookupEnv("AWS_REGION")
- if ok {
- defer os.Setenv("AWS_REGION", oldReg)
- }
- oldDefReg, ok := os.LookupEnv("AWS_DEFAULT_REGION")
- if ok {
- defer os.Setenv("AWS_REGION", oldDefReg)
- }
-
- os.Setenv("AWS_REGION", "kalamazoo")
+ // unset AWS region env vars for clean tests
+ os.Unsetenv("AWS_REGION")
os.Unsetenv("AWS_DEFAULT_REGION")
- region, err := getRegion()
- require.NoError(t, err)
- assert.Empty(t, region)
- os.Setenv("AWS_DEFAULT_REGION", "kalamazoo")
- os.Unsetenv("AWS_REGION")
- region, err = getRegion()
- require.NoError(t, err)
- assert.Empty(t, region)
+ t.Run("with AWS_REGION set", func(t *testing.T) {
+ t.Setenv("AWS_REGION", "kalamazoo")
+ region, err := getRegion()
+ require.NoError(t, err)
+ assert.Empty(t, region)
+ })
- os.Unsetenv("AWS_DEFAULT_REGION")
- metaClient := NewDummyEc2Meta()
- region, err = getRegion(metaClient)
- require.NoError(t, err)
- assert.Equal(t, "unknown", region)
+ t.Run("with AWS_DEFAULT_REGION set", func(t *testing.T) {
+ t.Setenv("AWS_DEFAULT_REGION", "kalamazoo")
+ region, err := getRegion()
+ require.NoError(t, err)
+ assert.Empty(t, region)
+ })
- ec2meta := MockEC2Meta(nil, nil, "us-east-1")
+ t.Run("with no AWS_REGION, AWS_DEFAULT_REGION set", func(t *testing.T) {
+ metaClient := NewDummyEc2Meta()
+ region, err := getRegion(metaClient)
+ require.NoError(t, err)
+ assert.Equal(t, "unknown", region)
+ })
- region, err = getRegion(ec2meta)
- require.NoError(t, err)
- assert.Equal(t, "us-east-1", region)
+ t.Run("infer from EC2 metadata", func(t *testing.T) {
+ ec2meta := MockEC2Meta(nil, nil, "us-east-1")
+ region, err := getRegion(ec2meta)
+ require.NoError(t, err)
+ assert.Equal(t, "us-east-1", region)
+ })
}
func TestGetClientOptions(t *testing.T) {
- oldVar, ok := os.LookupEnv("AWS_TIMEOUT")
- if ok {
- defer os.Setenv("AWS_TIMEOUT", oldVar)
- }
-
co := GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 500 * time.Millisecond}, co)
- os.Setenv("AWS_TIMEOUT", "42")
- // reset the Once
- coInit = sync.Once{}
- co = GetClientOptions()
- assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
-
- os.Setenv("AWS_TIMEOUT", "123")
- // without resetting the Once, expect to be reused
- co = GetClientOptions()
- assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
-
- os.Setenv("AWS_TIMEOUT", "foo")
- // reset the Once
- coInit = sync.Once{}
- assert.Panics(t, func() {
- GetClientOptions()
+ t.Run("valid AWS_TIMEOUT, first call", func(t *testing.T) {
+ t.Setenv("AWS_TIMEOUT", "42")
+ // reset the Once
+ coInit = sync.Once{}
+ co = GetClientOptions()
+ assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
+ })
+
+ t.Run("valid AWS_TIMEOUT, non-first call", func(t *testing.T) {
+ t.Setenv("AWS_TIMEOUT", "123")
+ // without resetting the Once, expect to be reused
+ co = GetClientOptions()
+ assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
+ })
+
+ t.Run("invalid AWS_TIMEOUT", func(t *testing.T) {
+ t.Setenv("AWS_TIMEOUT", "foo")
+ // reset the Once
+ coInit = sync.Once{}
+ assert.Panics(t, func() {
+ GetClientOptions()
+ })
})
}
diff --git a/context_test.go b/context_test.go
index 8fd4da71..aa3a4468 100644
--- a/context_test.go
+++ b/context_test.go
@@ -21,7 +21,7 @@ func TestEnvMapifiesEnvironment(t *testing.T) {
func TestEnvGetsUpdatedEnvironment(t *testing.T) {
c := &tmplctx{}
assert.Empty(t, c.Env()["FOO"])
- require.NoError(t, os.Setenv("FOO", "foo"))
+ t.Setenv("FOO", "foo")
assert.Equal(t, c.Env()["FOO"], "foo")
}
@@ -42,8 +42,7 @@ func TestCreateContext(t *testing.T) {
".": {URL: ub},
},
}
- os.Setenv("foo", "foo: bar")
- defer os.Unsetenv("foo")
+ t.Setenv("foo", "foo: bar")
c, err = createTmplContext(ctx, []string{"foo"}, d)
require.NoError(t, err)
assert.IsType(t, &tmplctx{}, c)
@@ -51,8 +50,7 @@ func TestCreateContext(t *testing.T) {
ds := ((*tctx)["foo"]).(map[string]interface{})
assert.Equal(t, "bar", ds["foo"])
- os.Setenv("bar", "bar: baz")
- defer os.Unsetenv("bar")
+ t.Setenv("bar", "bar: baz")
c, err = createTmplContext(ctx, []string{"."}, d)
require.NoError(t, err)
assert.IsType(t, map[string]interface{}{}, c)
diff --git a/data/datasource_blob_test.go b/data/datasource_blob_test.go
index c79780d0..6be8ea00 100644
--- a/data/datasource_blob_test.go
+++ b/data/datasource_blob_test.go
@@ -5,7 +5,6 @@ import (
"context"
"net/http/httptest"
"net/url"
- "os"
"testing"
"github.com/johannesboyne/gofakes3"
@@ -64,50 +63,48 @@ func TestReadBlob(t *testing.T) {
ts, u := setupTestBucket(t)
defer ts.Close()
- os.Setenv("AWS_ANON", "true")
- defer os.Unsetenv("AWS_ANON")
+ t.Run("no authentication", func(t *testing.T) {
+ t.Setenv("AWS_ANON", "true")
- d, err := NewData([]string{"-d", "data=s3://mybucket/file1?region=us-east-1&disableSSL=true&s3ForcePathStyle=true&type=text/plain&endpoint=" + u.Host}, nil)
- require.NoError(t, err)
+ d, err := NewData([]string{"-d", "data=s3://mybucket/file1?region=us-east-1&disableSSL=true&s3ForcePathStyle=true&type=text/plain&endpoint=" + u.Host}, nil)
+ require.NoError(t, err)
- var expected interface{}
- expected = "hello"
- out, err := d.Datasource("data")
- require.NoError(t, err)
- assert.Equal(t, expected, out)
+ expected := "hello"
+ out, err := d.Datasource("data")
+ require.NoError(t, err)
+ assert.Equal(t, expected, out)
+ })
- os.Unsetenv("AWS_ANON")
+ t.Run("with authentication", func(t *testing.T) {
+ t.Setenv("AWS_ACCESS_KEY_ID", "fake")
+ t.Setenv("AWS_SECRET_ACCESS_KEY", "fake")
+ t.Setenv("AWS_S3_ENDPOINT", u.Host)
- os.Setenv("AWS_ACCESS_KEY_ID", "fake")
- os.Setenv("AWS_SECRET_ACCESS_KEY", "fake")
- defer os.Unsetenv("AWS_ACCESS_KEY_ID")
- defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
- os.Setenv("AWS_S3_ENDPOINT", u.Host)
- defer os.Unsetenv("AWS_S3_ENDPOINT")
+ d, err := NewData([]string{"-d", "data=s3://mybucket/file2?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
+ require.NoError(t, err)
- d, err = NewData([]string{"-d", "data=s3://mybucket/file2?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
- require.NoError(t, err)
+ var expected interface{}
+ expected = map[string]interface{}{"value": "goodbye world"}
+ out, err := d.Datasource("data")
+ require.NoError(t, err)
+ assert.Equal(t, expected, out)
- expected = map[string]interface{}{"value": "goodbye world"}
- out, err = d.Datasource("data")
- require.NoError(t, err)
- assert.Equal(t, expected, out)
+ d, err = NewData([]string{"-d", "data=s3://mybucket/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
+ require.NoError(t, err)
- d, err = NewData([]string{"-d", "data=s3://mybucket/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
- require.NoError(t, err)
+ expected = []interface{}{"dir1/", "file1", "file2", "file3"}
+ out, err = d.Datasource("data")
+ require.NoError(t, err)
+ assert.EqualValues(t, expected, out)
- expected = []interface{}{"dir1/", "file1", "file2", "file3"}
- out, err = d.Datasource("data")
- require.NoError(t, err)
- assert.EqualValues(t, expected, out)
-
- d, err = NewData([]string{"-d", "data=s3://mybucket/dir1/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
- require.NoError(t, err)
+ d, err = NewData([]string{"-d", "data=s3://mybucket/dir1/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
+ require.NoError(t, err)
- expected = []interface{}{"file1", "file2"}
- out, err = d.Datasource("data")
- require.NoError(t, err)
- assert.EqualValues(t, expected, out)
+ expected = []interface{}{"file1", "file2"}
+ out, err = d.Datasource("data")
+ require.NoError(t, err)
+ assert.EqualValues(t, expected, out)
+ })
}
func TestBlobURL(t *testing.T) {
diff --git a/data/datasource_env_test.go b/data/datasource_env_test.go
index 7a10b347..6512578c 100644
--- a/data/datasource_env_test.go
+++ b/data/datasource_env_test.go
@@ -3,7 +3,6 @@ package data
import (
"context"
"net/url"
- "os"
"testing"
"github.com/stretchr/testify/assert"
@@ -19,10 +18,8 @@ func TestReadEnv(t *testing.T) {
ctx := context.Background()
content := []byte(`hello world`)
- os.Setenv("HELLO_WORLD", "hello world")
- defer os.Unsetenv("HELLO_WORLD")
- os.Setenv("HELLO_UNIVERSE", "hello universe")
- defer os.Unsetenv("HELLO_UNIVERSE")
+ t.Setenv("HELLO_WORLD", "hello world")
+ t.Setenv("HELLO_UNIVERSE", "hello universe")
source := &Source{Alias: "foo", URL: mustParseURL("env:HELLO_WORLD")}
diff --git a/data/datasource_git_test.go b/data/datasource_git_test.go
index 3f74a620..3b187ecc 100644
--- a/data/datasource_git_test.go
+++ b/data/datasource_git_test.go
@@ -484,15 +484,13 @@ func TestGitAuth(t *testing.T) {
assert.NilError(t, err)
assert.DeepEqual(t, &http.BasicAuth{Username: "user", Password: "swordfish"}, a)
- os.Setenv("GIT_HTTP_PASSWORD", "swordfish")
- defer os.Unsetenv("GIT_HTTP_PASSWORD")
+ t.Setenv("GIT_HTTP_PASSWORD", "swordfish")
a, err = g.auth(mustParseURL("git+https://user@example.com/foo"))
assert.NilError(t, err)
assert.DeepEqual(t, &http.BasicAuth{Username: "user", Password: "swordfish"}, a)
os.Unsetenv("GIT_HTTP_PASSWORD")
- os.Setenv("GIT_HTTP_TOKEN", "mytoken")
- defer os.Unsetenv("GIT_HTTP_TOKEN")
+ t.Setenv("GIT_HTTP_TOKEN", "mytoken")
a, err = g.auth(mustParseURL("git+https://user@example.com/foo"))
assert.NilError(t, err)
assert.DeepEqual(t, &http.TokenAuth{Token: "mytoken"}, a)
@@ -508,25 +506,26 @@ func TestGitAuth(t *testing.T) {
assert.Equal(t, "git", sa.User)
}
- key := string(testdata.PEMBytes["ed25519"])
- os.Setenv("GIT_SSH_KEY", key)
- defer os.Unsetenv("GIT_SSH_KEY")
- a, err = g.auth(mustParseURL("git+ssh://git@example.com/foo"))
- assert.NilError(t, err)
- ka, ok := a.(*ssh.PublicKeys)
- assert.Equal(t, true, ok)
- assert.Equal(t, "git", ka.User)
- os.Unsetenv("GIT_SSH_KEY")
-
- key = base64.StdEncoding.EncodeToString(testdata.PEMBytes["ed25519"])
- os.Setenv("GIT_SSH_KEY", key)
- defer os.Unsetenv("GIT_SSH_KEY")
- a, err = g.auth(mustParseURL("git+ssh://git@example.com/foo"))
- assert.NilError(t, err)
- ka, ok = a.(*ssh.PublicKeys)
- assert.Equal(t, true, ok)
- assert.Equal(t, "git", ka.User)
- os.Unsetenv("GIT_SSH_KEY")
+ t.Run("plain string ed25519", func(t *testing.T) {
+ key := string(testdata.PEMBytes["ed25519"])
+ t.Setenv("GIT_SSH_KEY", key)
+ a, err = g.auth(mustParseURL("git+ssh://git@example.com/foo"))
+ assert.NilError(t, err)
+ ka, ok := a.(*ssh.PublicKeys)
+ assert.Equal(t, true, ok)
+ assert.Equal(t, "git", ka.User)
+ })
+
+ t.Run("base64 ed25519", func(t *testing.T) {
+ key := base64.StdEncoding.EncodeToString(testdata.PEMBytes["ed25519"])
+ t.Setenv("GIT_SSH_KEY", key)
+ a, err = g.auth(mustParseURL("git+ssh://git@example.com/foo"))
+ assert.NilError(t, err)
+ ka, ok := a.(*ssh.PublicKeys)
+ assert.Equal(t, true, ok)
+ assert.Equal(t, "git", ka.User)
+ os.Unsetenv("GIT_SSH_KEY")
+ })
}
func TestRefFromURL(t *testing.T) {
diff --git a/internal/cmd/config_test.go b/internal/cmd/config_test.go
index f4c032c8..0cac9983 100644
--- a/internal/cmd/config_test.go
+++ b/internal/cmd/config_test.go
@@ -6,7 +6,6 @@ import (
"fmt"
"io/fs"
"net/url"
- "os"
"testing"
"testing/fstest"
"time"
@@ -167,32 +166,38 @@ func TestPickConfigFile(t *testing.T) {
cmd := &cobra.Command{}
cmd.Flags().String("config", defaultConfigFile, "foo")
- cf, req := pickConfigFile(cmd)
- assert.False(t, req)
- assert.Equal(t, defaultConfigFile, cf)
-
- os.Setenv("GOMPLATE_CONFIG", "foo.yaml")
- defer os.Unsetenv("GOMPLATE_CONFIG")
- cf, req = pickConfigFile(cmd)
- assert.True(t, req)
- assert.Equal(t, "foo.yaml", cf)
-
- cmd.ParseFlags([]string{"--config", "config.file"})
- cf, req = pickConfigFile(cmd)
- assert.True(t, req)
- assert.Equal(t, "config.file", cf)
-
- os.Setenv("GOMPLATE_CONFIG", "ignored.yaml")
- cf, req = pickConfigFile(cmd)
- assert.True(t, req)
- assert.Equal(t, "config.file", cf)
+ t.Run("default", func(t *testing.T) {
+ cf, req := pickConfigFile(cmd)
+ assert.False(t, req)
+ assert.Equal(t, defaultConfigFile, cf)
+ })
+
+ t.Run("GOMPLATE_CONFIG env var", func(t *testing.T) {
+ t.Setenv("GOMPLATE_CONFIG", "foo.yaml")
+ cf, req := pickConfigFile(cmd)
+ assert.True(t, req)
+ assert.Equal(t, "foo.yaml", cf)
+ })
+
+ t.Run("--config flag", func(t *testing.T) {
+ cmd.ParseFlags([]string{"--config", "config.file"})
+ cf, req := pickConfigFile(cmd)
+ assert.True(t, req)
+ assert.Equal(t, "config.file", cf)
+
+ t.Setenv("GOMPLATE_CONFIG", "ignored.yaml")
+ cf, req = pickConfigFile(cmd)
+ assert.True(t, req)
+ assert.Equal(t, "config.file", cf)
+ })
}
func TestApplyEnvVars(t *testing.T) {
- os.Setenv("GOMPLATE_PLUGIN_TIMEOUT", "bogus")
- _, err := applyEnvVars(context.Background(), &config.Config{})
- os.Unsetenv("GOMPLATE_PLUGIN_TIMEOUT")
- assert.Error(t, err)
+ t.Run("invalid GOMPLATE_PLUGIN_TIMEOUT", func(t *testing.T) {
+ t.Setenv("GOMPLATE_PLUGIN_TIMEOUT", "bogus")
+ _, err := applyEnvVars(context.Background(), &config.Config{})
+ assert.Error(t, err)
+ })
data := []struct {
input, expected *config.Config
@@ -274,10 +279,9 @@ func TestApplyEnvVars(t *testing.T) {
for i, d := range data {
d := d
t.Run(fmt.Sprintf("applyEnvVars_%s_%s/%d", d.env, d.value, i), func(t *testing.T) {
- os.Setenv(d.env, d.value)
+ t.Setenv(d.env, d.value)
actual, err := applyEnvVars(context.Background(), d.input)
- os.Unsetenv(d.env)
require.NoError(t, err)
assert.EqualValues(t, d.expected, actual)
})
diff --git a/internal/cmd/logger_test.go b/internal/cmd/logger_test.go
index 72f8f292..ea595c1e 100644
--- a/internal/cmd/logger_test.go
+++ b/internal/cmd/logger_test.go
@@ -14,13 +14,12 @@ import (
func TestLogFormat(t *testing.T) {
os.Unsetenv("GOMPLATE_LOG_FORMAT")
- defer os.Unsetenv("GOMPLATE_LOG_FORMAT")
assert.Equal(t, "json", logFormat(nil))
// os.Stdout isn't a terminal when this runs as a unit test...
assert.Equal(t, "json", logFormat(os.Stdout))
- os.Setenv("GOMPLATE_LOG_FORMAT", "simple")
+ t.Setenv("GOMPLATE_LOG_FORMAT", "simple")
assert.Equal(t, "simple", logFormat(os.Stdout))
assert.Equal(t, "simple", logFormat(&bytes.Buffer{}))
}
diff --git a/libkv/consul_test.go b/libkv/consul_test.go
index 74d78d01..5f128aa7 100644
--- a/libkv/consul_test.go
+++ b/libkv/consul_test.go
@@ -14,68 +14,83 @@ import (
)
func TestConsulURL(t *testing.T) {
- defer os.Unsetenv("CONSUL_HTTP_SSL")
- os.Setenv("CONSUL_HTTP_SSL", "true")
-
- u, _ := url.Parse("consul://")
- expected := &url.URL{Host: "localhost:8500", Scheme: "https"}
- actual, err := consulURL(u)
- require.NoError(t, err)
- assert.Equal(t, expected, actual)
-
- u, _ = url.Parse("consul+http://myconsul.server")
- expected = &url.URL{Host: "myconsul.server", Scheme: "http"}
- actual, err = consulURL(u)
- require.NoError(t, err)
- assert.Equal(t, expected, actual)
-
- os.Setenv("CONSUL_HTTP_SSL", "false")
- u, _ = url.Parse("consul+https://myconsul.server:1234")
- expected = &url.URL{Host: "myconsul.server:1234", Scheme: "https"}
- actual, err = consulURL(u)
- require.NoError(t, err)
- assert.Equal(t, expected, actual)
-
os.Unsetenv("CONSUL_HTTP_SSL")
- u, _ = url.Parse("consul://myconsul.server:2345")
- expected = &url.URL{Host: "myconsul.server:2345", Scheme: "http"}
- actual, err = consulURL(u)
- require.NoError(t, err)
- assert.Equal(t, expected, actual)
-
- u, _ = url.Parse("consul://myconsul.server:3456/foo/bar/baz")
- expected = &url.URL{Host: "myconsul.server:3456", Scheme: "http"}
- actual, err = consulURL(u)
- require.NoError(t, err)
- assert.Equal(t, expected, actual)
-
- defer os.Unsetenv("CONSUL_HTTP_ADDR")
- os.Setenv("CONSUL_HTTP_ADDR", "https://foo:8500")
-
- // given URL takes precedence over env var
- expected = &url.URL{Host: "myconsul.server:3456", Scheme: "http"}
- actual, err = consulURL(u)
- require.NoError(t, err)
- assert.Equal(t, expected, actual)
-
- u, _ = url.Parse("consul://")
-
- defer os.Unsetenv("CONSUL_HTTP_SSL")
- os.Setenv("CONSUL_HTTP_SSL", "true")
-
- // TLS enabled, HTTP_ADDR is set, URL has no host and ambiguous scheme
- expected = &url.URL{Host: "foo:8500", Scheme: "https"}
- actual, err = consulURL(u)
- require.NoError(t, err)
- assert.Equal(t, expected, actual)
-
- defer os.Unsetenv("CONSUL_HTTP_ADDR")
- os.Setenv("CONSUL_HTTP_ADDR", "localhost:8501")
- expected = &url.URL{Host: "localhost:8501", Scheme: "https"}
- actual, err = consulURL(u)
- require.NoError(t, err)
- assert.Equal(t, expected, actual)
+ t.Run("consul scheme, CONSUL_HTTP_SSL set to true", func(t *testing.T) {
+ t.Setenv("CONSUL_HTTP_SSL", "true")
+
+ u, _ := url.Parse("consul://")
+ expected := &url.URL{Host: "localhost:8500", Scheme: "https"}
+ actual, err := consulURL(u)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+ })
+
+ t.Run("consul+http scheme", func(t *testing.T) {
+ u, _ := url.Parse("consul+http://myconsul.server")
+ expected := &url.URL{Host: "myconsul.server", Scheme: "http"}
+ actual, err := consulURL(u)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+ })
+
+ t.Run("consul+https scheme, CONSUL_HTTP_SSL set to false", func(t *testing.T) {
+ t.Setenv("CONSUL_HTTP_SSL", "false")
+
+ u, _ := url.Parse("consul+https://myconsul.server:1234")
+ expected := &url.URL{Host: "myconsul.server:1234", Scheme: "https"}
+ actual, err := consulURL(u)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+ })
+
+ t.Run("consul scheme, CONSUL_HTTP_SSL unset", func(t *testing.T) {
+ u, _ := url.Parse("consul://myconsul.server:2345")
+ expected := &url.URL{Host: "myconsul.server:2345", Scheme: "http"}
+ actual, err := consulURL(u)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+ })
+
+ t.Run("consul scheme, ignore path", func(t *testing.T) {
+ u, _ := url.Parse("consul://myconsul.server:3456/foo/bar/baz")
+ expected := &url.URL{Host: "myconsul.server:3456", Scheme: "http"}
+ actual, err := consulURL(u)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+ })
+
+ t.Run("given URL takes precedence over env var", func(t *testing.T) {
+ t.Setenv("CONSUL_HTTP_ADDR", "https://foo:8500")
+
+ u, _ := url.Parse("consul://myconsul.server:3456/foo/bar/baz")
+ expected := &url.URL{Host: "myconsul.server:3456", Scheme: "http"}
+ actual, err := consulURL(u)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+ })
+
+ t.Run("TLS enabled, HTTP_ADDR is set, URL has no host and ambiguous scheme", func(t *testing.T) {
+ t.Setenv("CONSUL_HTTP_ADDR", "https://foo:8500")
+ t.Setenv("CONSUL_HTTP_SSL", "true")
+
+ u, _ := url.Parse("consul://")
+ expected := &url.URL{Host: "foo:8500", Scheme: "https"}
+ actual, err := consulURL(u)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+ })
+
+ t.Run("TLS enabled, HTTP_ADDR is set without scheme, URL has no host and ambiguous scheme", func(t *testing.T) {
+ t.Setenv("CONSUL_HTTP_ADDR", "localhost:8501")
+ t.Setenv("CONSUL_HTTP_SSL", "true")
+
+ u, _ := url.Parse("consul://")
+ expected := &url.URL{Host: "localhost:8501", Scheme: "https"}
+ actual, err := consulURL(u)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+ })
}
func TestConsulAddrFromEnv(t *testing.T) {
@@ -106,55 +121,56 @@ func TestSetupTLS(t *testing.T) {
KeyFile: "keyfile",
}
- defer os.Unsetenv("CONSUL_TLS_SERVER_NAME")
- defer os.Unsetenv("CONSUL_CACERT")
- defer os.Unsetenv("CONSUL_CAPATH")
- defer os.Unsetenv("CONSUL_CLIENT_CERT")
- defer os.Unsetenv("CONSUL_CLIENT_KEY")
- os.Setenv("CONSUL_TLS_SERVER_NAME", expected.Address)
- os.Setenv("CONSUL_CACERT", expected.CAFile)
- os.Setenv("CONSUL_CAPATH", expected.CAPath)
- os.Setenv("CONSUL_CLIENT_CERT", expected.CertFile)
- os.Setenv("CONSUL_CLIENT_KEY", expected.KeyFile)
-
- assert.Equal(t, expected, setupTLS())
+ t.Setenv("CONSUL_TLS_SERVER_NAME", expected.Address)
+ t.Setenv("CONSUL_CACERT", expected.CAFile)
+ t.Setenv("CONSUL_CAPATH", expected.CAPath)
+ t.Setenv("CONSUL_CLIENT_CERT", expected.CertFile)
+ t.Setenv("CONSUL_CLIENT_KEY", expected.KeyFile)
- expected.InsecureSkipVerify = false
- defer os.Unsetenv("CONSUL_HTTP_SSL_VERIFY")
- os.Setenv("CONSUL_HTTP_SSL_VERIFY", "true")
assert.Equal(t, expected, setupTLS())
- expected.InsecureSkipVerify = true
- os.Setenv("CONSUL_HTTP_SSL_VERIFY", "false")
- assert.Equal(t, expected, setupTLS())
+ t.Run("CONSUL_HTTP_SSL_VERIFY is true", func(t *testing.T) {
+ expected.InsecureSkipVerify = false
+ t.Setenv("CONSUL_HTTP_SSL_VERIFY", "true")
+ assert.Equal(t, expected, setupTLS())
+ })
+
+ t.Run("CONSUL_HTTP_SSL_VERIFY is false", func(t *testing.T) {
+ expected.InsecureSkipVerify = true
+ t.Setenv("CONSUL_HTTP_SSL_VERIFY", "false")
+ assert.Equal(t, expected, setupTLS())
+ })
}
func TestConsulConfig(t *testing.T) {
- expectedConfig := &store.Config{}
-
- actualConfig, err := consulConfig(false)
- require.NoError(t, err)
-
- assert.Equal(t, expectedConfig, actualConfig)
-
- defer os.Unsetenv("CONSUL_TIMEOUT")
- os.Setenv("CONSUL_TIMEOUT", "10")
- expectedConfig = &store.Config{
- ConnectionTimeout: 10 * time.Second,
- }
-
- actualConfig, err = consulConfig(false)
- require.NoError(t, err)
- assert.Equal(t, expectedConfig, actualConfig)
-
- os.Unsetenv("CONSUL_TIMEOUT")
- expectedConfig = &store.Config{
- TLS: &tls.Config{MinVersion: tls.VersionTLS13},
- }
-
- actualConfig, err = consulConfig(true)
- require.NoError(t, err)
- assert.NotNil(t, actualConfig.TLS)
- actualConfig.TLS = &tls.Config{MinVersion: tls.VersionTLS13}
- assert.Equal(t, expectedConfig, actualConfig)
+ t.Run("default ", func(t *testing.T) {
+ expectedConfig := &store.Config{}
+
+ actualConfig, err := consulConfig(false)
+ require.NoError(t, err)
+
+ assert.Equal(t, expectedConfig, actualConfig)
+ })
+
+ t.Run("with CONSUL_TIMEOUT", func(t *testing.T) {
+ t.Setenv("CONSUL_TIMEOUT", "10")
+ expectedConfig := &store.Config{
+ ConnectionTimeout: 10 * time.Second,
+ }
+
+ actualConfig, err := consulConfig(false)
+ require.NoError(t, err)
+ assert.Equal(t, expectedConfig, actualConfig)
+ })
+
+ t.Run("with TLS", func(t *testing.T) {
+ expectedConfig := &store.Config{
+ TLS: &tls.Config{MinVersion: tls.VersionTLS13},
+ }
+ actualConfig, err := consulConfig(true)
+ require.NoError(t, err)
+ assert.NotNil(t, actualConfig.TLS)
+ actualConfig.TLS = &tls.Config{MinVersion: tls.VersionTLS13}
+ assert.Equal(t, expectedConfig, actualConfig)
+ })
}
diff --git a/render_test.go b/render_test.go
index 21b0be74..684367fa 100644
--- a/render_test.go
+++ b/render_test.go
@@ -37,8 +37,7 @@ func TestRenderTemplate(t *testing.T) {
hu, _ := url.Parse("stdin:")
wu, _ := url.Parse("env:WORLD")
- os.Setenv("WORLD", "world")
- defer os.Unsetenv("WORLD")
+ t.Setenv("WORLD", "world")
tr = NewRenderer(Options{
Context: map[string]Datasource{
diff --git a/vault/auth_test.go b/vault/auth_test.go
index 552db572..fa8f29cc 100644
--- a/vault/auth_test.go
+++ b/vault/auth_test.go
@@ -1,7 +1,6 @@
package vault
import (
- "os"
"testing"
"github.com/stretchr/testify/assert"
@@ -11,8 +10,7 @@ import (
func TestLogin(t *testing.T) {
server, v := MockServer(404, "Not Found")
defer server.Close()
- os.Setenv("VAULT_TOKEN", "foo")
- defer os.Unsetenv("VAULT_TOKEN")
+ t.Setenv("VAULT_TOKEN", "foo")
v.Login()
assert.Equal(t, "foo", v.client.Token())
}
@@ -20,8 +18,7 @@ func TestLogin(t *testing.T) {
func TestTokenLogin(t *testing.T) {
server, v := MockServer(404, "Not Found")
defer server.Close()
- os.Setenv("VAULT_TOKEN", "foo")
- defer os.Unsetenv("VAULT_TOKEN")
+ t.Setenv("VAULT_TOKEN", "foo")
token, err := v.TokenLogin()
require.NoError(t, err)
diff --git a/vault/vault_test.go b/vault/vault_test.go
index 754e2d4c..1cc5e31d 100644
--- a/vault/vault_test.go
+++ b/vault/vault_test.go
@@ -15,8 +15,7 @@ func TestNew(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "https://127.0.0.1:8200", v.client.Address())
- os.Setenv("VAULT_ADDR", "http://example.com:1234")
- defer os.Unsetenv("VAULT_ADDR")
+ t.Setenv("VAULT_ADDR", "http://example.com:1234")
v, err = New(nil)
require.NoError(t, err)
assert.Equal(t, "http://example.com:1234", v.client.Address())