diff options
| author | Dave Henderson <dhenderson@gmail.com> | 2019-01-23 21:14:49 -0500 |
|---|---|---|
| committer | Dave Henderson <dhenderson@gmail.com> | 2019-01-23 22:15:01 -0500 |
| commit | 1bef25c9000072bff4f1da2cf6a69ec4dcb68f96 (patch) | |
| tree | 68e58282219addbabeab9a06d1a4ec852708e397 /aws | |
| parent | bfa0628a23b0b393cb65aa2dbfb02e7a2c2db4aa (diff) | |
AWS region detection for SSM param store datasources
Signed-off-by: Dave Henderson <dhenderson@gmail.com>
Diffstat (limited to 'aws')
| -rw-r--r-- | aws/ec2info.go | 65 | ||||
| -rw-r--r-- | aws/ec2info_test.go | 66 | ||||
| -rw-r--r-- | aws/ec2meta.go | 5 |
3 files changed, 114 insertions, 22 deletions
diff --git a/aws/ec2info.go b/aws/ec2info.go index c5199449..3dcf37b6 100644 --- a/aws/ec2info.go +++ b/aws/ec2info.go @@ -22,10 +22,6 @@ var ( sdkSessionInit sync.Once ) -const ( - unknown = "unknown" -) - // ClientOptions - type ClientOptions struct { Timeout time.Duration @@ -48,14 +44,16 @@ type InstanceDescriber interface { func GetClientOptions() ClientOptions { coInit.Do(func() { timeout := os.Getenv("AWS_TIMEOUT") - if timeout != "" { - t, err := strconv.Atoi(timeout) - if err != nil { - panic(errors.Wrapf(err, "Invalid AWS_TIMEOUT value '%s' - must be an integer\n", timeout)) - } + if timeout == "" { + timeout = "500" + } - co.Timeout = time.Duration(t) * time.Millisecond + t, err := strconv.Atoi(timeout) + if err != nil { + panic(errors.Wrapf(err, "Invalid AWS_TIMEOUT value '%s' - must be an integer\n", timeout)) } + + co.Timeout = time.Duration(t) * time.Millisecond }) return co } @@ -72,16 +70,17 @@ func SDKSession(region ...string) *session.Session { config := aws.NewConfig() config = config.WithHTTPClient(&http.Client{Timeout: timeout}) - metaRegion := unknown + metaRegion := "" if len(region) > 0 { metaRegion = region[0] + } else { + var err error + metaRegion, err = getRegion() + if err != nil { + panic(errors.Wrap(err, "failed to determine EC2 region")) + } } - // Waiting for https://github.com/aws/aws-sdk-go/issues/1103 - _, default1 := os.LookupEnv("AWS_REGION") - _, default2 := os.LookupEnv("AWS_DEFAULT_REGION") - if metaRegion != unknown && !default1 && !default2 { - config = config.WithRegion(metaRegion) - } + config = config.WithRegion(metaRegion) sdkSession = session.Must(session.NewSessionWithOptions(session.Options{ Config: *config, @@ -91,17 +90,39 @@ func SDKSession(region ...string) *session.Session { return sdkSession } +// Attempts to get the EC2 region to use. If we're running on an EC2 Instance +// and neither AWS_REGION nor AWS_DEFAULT_REGION are set, we'll infer from EC2 +// metadata. +// Once https://github.com/aws/aws-sdk-go/issues/1103 is resolve this should be +// tidier! +func getRegion(m ...*Ec2Meta) (string, error) { + region := "" + _, default1 := os.LookupEnv("AWS_REGION") + _, default2 := os.LookupEnv("AWS_DEFAULT_REGION") + if !default1 && !default2 { + // Maybe we're in EC2, let's try to read metadata + var metaClient *Ec2Meta + if len(m) > 0 { + metaClient = m[0] + } else { + metaClient = NewEc2Meta(GetClientOptions()) + } + var err error + region, err = metaClient.Region() + if err != nil { + return "", errors.Wrap(err, "failed to determine EC2 region") + } + } + return region, nil +} + // NewEc2Info - func NewEc2Info(options ClientOptions) (info *Ec2Info) { metaClient := NewEc2Meta(options) return &Ec2Info{ describer: func() (InstanceDescriber, error) { if describerClient == nil { - metaRegion, err := metaClient.Region() - if err != nil { - return nil, errors.Wrap(err, "failed to determine EC2 region") - } - session := SDKSession(metaRegion) + session := SDKSession() describerClient = ec2.New(session) } return describerClient, nil diff --git a/aws/ec2info_test.go b/aws/ec2info_test.go index cd74de93..809704ff 100644 --- a/aws/ec2info_test.go +++ b/aws/ec2info_test.go @@ -1,7 +1,10 @@ package aws import ( + "os" + "sync" "testing" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" @@ -103,3 +106,66 @@ func TestNewEc2Info(t *testing.T) { assert.Equal(t, "bar", must(e.Tag("foo"))) assert.Equal(t, "bar", must(e.Tag("foo", "default"))) } + +func TestGetRegion(t *testing.T) { + oldReg, ok := os.LookupEnv("AWS_REGION") + if ok { + defer os.Setenv("AWS_REGION", oldReg) + } + oldDefReg, ok := os.LookupEnv("AWS_DEFAULT_REGION") + if ok { + defer os.Setenv("AWS_REGION", oldDefReg) + } + + os.Setenv("AWS_REGION", "kalamazoo") + os.Unsetenv("AWS_DEFAULT_REGION") + region, err := getRegion() + assert.NoError(t, err) + assert.Empty(t, region) + + os.Setenv("AWS_DEFAULT_REGION", "kalamazoo") + os.Unsetenv("AWS_REGION") + region, err = getRegion() + assert.NoError(t, err) + assert.Empty(t, region) + + os.Unsetenv("AWS_DEFAULT_REGION") + metaClient := NewDummyEc2Meta() + region, err = getRegion(metaClient) + assert.NoError(t, err) + assert.Equal(t, "unknown", region) + + server, ec2meta := MockServer(200, `{"region":"us-east-1"}`) + defer server.Close() + region, err = getRegion(ec2meta) + assert.NoError(t, err) + assert.Equal(t, "us-east-1", region) +} + +func TestGetClientOptions(t *testing.T) { + oldVar, ok := os.LookupEnv("AWS_TIMEOUT") + if ok { + defer os.Setenv("AWS_TIMEOUT", oldVar) + } + + co := GetClientOptions() + assert.Equal(t, ClientOptions{Timeout: 500 * time.Millisecond}, co) + + os.Setenv("AWS_TIMEOUT", "42") + // reset the Once + coInit = sync.Once{} + co = GetClientOptions() + assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co) + + os.Setenv("AWS_TIMEOUT", "123") + // without resetting the Once, expect to be reused + co = GetClientOptions() + assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co) + + os.Setenv("AWS_TIMEOUT", "foo") + // reset the Once + coInit = sync.Once{} + assert.Panics(t, func() { + GetClientOptions() + }) +} diff --git a/aws/ec2meta.go b/aws/ec2meta.go index 901cfcff..61a78b23 100644 --- a/aws/ec2meta.go +++ b/aws/ec2meta.go @@ -15,6 +15,11 @@ import ( // DefaultEndpoint - var DefaultEndpoint = "http://169.254.169.254" +const ( + // the default region + unknown = "unknown" +) + // Ec2Meta - type Ec2Meta struct { Endpoint string |
