summaryrefslogtreecommitdiff
path: root/aws
diff options
context:
space:
mode:
authorDave Henderson <dhenderson@gmail.com>2019-01-23 21:14:49 -0500
committerDave Henderson <dhenderson@gmail.com>2019-01-23 22:15:01 -0500
commit1bef25c9000072bff4f1da2cf6a69ec4dcb68f96 (patch)
tree68e58282219addbabeab9a06d1a4ec852708e397 /aws
parentbfa0628a23b0b393cb65aa2dbfb02e7a2c2db4aa (diff)
AWS region detection for SSM param store datasources
Signed-off-by: Dave Henderson <dhenderson@gmail.com>
Diffstat (limited to 'aws')
-rw-r--r--aws/ec2info.go65
-rw-r--r--aws/ec2info_test.go66
-rw-r--r--aws/ec2meta.go5
3 files changed, 114 insertions, 22 deletions
diff --git a/aws/ec2info.go b/aws/ec2info.go
index c5199449..3dcf37b6 100644
--- a/aws/ec2info.go
+++ b/aws/ec2info.go
@@ -22,10 +22,6 @@ var (
sdkSessionInit sync.Once
)
-const (
- unknown = "unknown"
-)
-
// ClientOptions -
type ClientOptions struct {
Timeout time.Duration
@@ -48,14 +44,16 @@ type InstanceDescriber interface {
func GetClientOptions() ClientOptions {
coInit.Do(func() {
timeout := os.Getenv("AWS_TIMEOUT")
- if timeout != "" {
- t, err := strconv.Atoi(timeout)
- if err != nil {
- panic(errors.Wrapf(err, "Invalid AWS_TIMEOUT value '%s' - must be an integer\n", timeout))
- }
+ if timeout == "" {
+ timeout = "500"
+ }
- co.Timeout = time.Duration(t) * time.Millisecond
+ t, err := strconv.Atoi(timeout)
+ if err != nil {
+ panic(errors.Wrapf(err, "Invalid AWS_TIMEOUT value '%s' - must be an integer\n", timeout))
}
+
+ co.Timeout = time.Duration(t) * time.Millisecond
})
return co
}
@@ -72,16 +70,17 @@ func SDKSession(region ...string) *session.Session {
config := aws.NewConfig()
config = config.WithHTTPClient(&http.Client{Timeout: timeout})
- metaRegion := unknown
+ metaRegion := ""
if len(region) > 0 {
metaRegion = region[0]
+ } else {
+ var err error
+ metaRegion, err = getRegion()
+ if err != nil {
+ panic(errors.Wrap(err, "failed to determine EC2 region"))
+ }
}
- // Waiting for https://github.com/aws/aws-sdk-go/issues/1103
- _, default1 := os.LookupEnv("AWS_REGION")
- _, default2 := os.LookupEnv("AWS_DEFAULT_REGION")
- if metaRegion != unknown && !default1 && !default2 {
- config = config.WithRegion(metaRegion)
- }
+ config = config.WithRegion(metaRegion)
sdkSession = session.Must(session.NewSessionWithOptions(session.Options{
Config: *config,
@@ -91,17 +90,39 @@ func SDKSession(region ...string) *session.Session {
return sdkSession
}
+// Attempts to get the EC2 region to use. If we're running on an EC2 Instance
+// and neither AWS_REGION nor AWS_DEFAULT_REGION are set, we'll infer from EC2
+// metadata.
+// Once https://github.com/aws/aws-sdk-go/issues/1103 is resolve this should be
+// tidier!
+func getRegion(m ...*Ec2Meta) (string, error) {
+ region := ""
+ _, default1 := os.LookupEnv("AWS_REGION")
+ _, default2 := os.LookupEnv("AWS_DEFAULT_REGION")
+ if !default1 && !default2 {
+ // Maybe we're in EC2, let's try to read metadata
+ var metaClient *Ec2Meta
+ if len(m) > 0 {
+ metaClient = m[0]
+ } else {
+ metaClient = NewEc2Meta(GetClientOptions())
+ }
+ var err error
+ region, err = metaClient.Region()
+ if err != nil {
+ return "", errors.Wrap(err, "failed to determine EC2 region")
+ }
+ }
+ return region, nil
+}
+
// NewEc2Info -
func NewEc2Info(options ClientOptions) (info *Ec2Info) {
metaClient := NewEc2Meta(options)
return &Ec2Info{
describer: func() (InstanceDescriber, error) {
if describerClient == nil {
- metaRegion, err := metaClient.Region()
- if err != nil {
- return nil, errors.Wrap(err, "failed to determine EC2 region")
- }
- session := SDKSession(metaRegion)
+ session := SDKSession()
describerClient = ec2.New(session)
}
return describerClient, nil
diff --git a/aws/ec2info_test.go b/aws/ec2info_test.go
index cd74de93..809704ff 100644
--- a/aws/ec2info_test.go
+++ b/aws/ec2info_test.go
@@ -1,7 +1,10 @@
package aws
import (
+ "os"
+ "sync"
"testing"
+ "time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
@@ -103,3 +106,66 @@ func TestNewEc2Info(t *testing.T) {
assert.Equal(t, "bar", must(e.Tag("foo")))
assert.Equal(t, "bar", must(e.Tag("foo", "default")))
}
+
+func TestGetRegion(t *testing.T) {
+ oldReg, ok := os.LookupEnv("AWS_REGION")
+ if ok {
+ defer os.Setenv("AWS_REGION", oldReg)
+ }
+ oldDefReg, ok := os.LookupEnv("AWS_DEFAULT_REGION")
+ if ok {
+ defer os.Setenv("AWS_REGION", oldDefReg)
+ }
+
+ os.Setenv("AWS_REGION", "kalamazoo")
+ os.Unsetenv("AWS_DEFAULT_REGION")
+ region, err := getRegion()
+ assert.NoError(t, err)
+ assert.Empty(t, region)
+
+ os.Setenv("AWS_DEFAULT_REGION", "kalamazoo")
+ os.Unsetenv("AWS_REGION")
+ region, err = getRegion()
+ assert.NoError(t, err)
+ assert.Empty(t, region)
+
+ os.Unsetenv("AWS_DEFAULT_REGION")
+ metaClient := NewDummyEc2Meta()
+ region, err = getRegion(metaClient)
+ assert.NoError(t, err)
+ assert.Equal(t, "unknown", region)
+
+ server, ec2meta := MockServer(200, `{"region":"us-east-1"}`)
+ defer server.Close()
+ region, err = getRegion(ec2meta)
+ assert.NoError(t, err)
+ assert.Equal(t, "us-east-1", region)
+}
+
+func TestGetClientOptions(t *testing.T) {
+ oldVar, ok := os.LookupEnv("AWS_TIMEOUT")
+ if ok {
+ defer os.Setenv("AWS_TIMEOUT", oldVar)
+ }
+
+ co := GetClientOptions()
+ assert.Equal(t, ClientOptions{Timeout: 500 * time.Millisecond}, co)
+
+ os.Setenv("AWS_TIMEOUT", "42")
+ // reset the Once
+ coInit = sync.Once{}
+ co = GetClientOptions()
+ assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
+
+ os.Setenv("AWS_TIMEOUT", "123")
+ // without resetting the Once, expect to be reused
+ co = GetClientOptions()
+ assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
+
+ os.Setenv("AWS_TIMEOUT", "foo")
+ // reset the Once
+ coInit = sync.Once{}
+ assert.Panics(t, func() {
+ GetClientOptions()
+ })
+}
diff --git a/aws/ec2meta.go b/aws/ec2meta.go
index 901cfcff..61a78b23 100644
--- a/aws/ec2meta.go
+++ b/aws/ec2meta.go
@@ -15,6 +15,11 @@ import (
// DefaultEndpoint -
var DefaultEndpoint = "http://169.254.169.254"
+const (
+ // the default region
+ unknown = "unknown"
+)
+
// Ec2Meta -
type Ec2Meta struct {
Endpoint string