From 8f6cd731a77b015bee4c98e631363db02cc7063a Mon Sep 17 00:00:00 2001 From: Suresh Kumar Date: Tue, 31 May 2022 01:36:12 +0530 Subject: Add support for aws imdsv2 (#1402) * Use ec2metadata package from aws-sdk This will handle IMDSv2 etc tranparently (and as well any future changes) * Fix linter errors * Fix dynamic data retrieval * Fix vault ec2 auth test failures --- aws/ec2info_test.go | 22 ++++----- aws/ec2meta.go | 136 +++++++++++++++++++++++++++------------------------- aws/ec2meta_test.go | 21 +++----- aws/testutils.go | 72 ++++++++++++++++------------ gomplate_test.go | 20 ++++---- 5 files changed, 136 insertions(+), 135 deletions(-) diff --git a/aws/ec2info_test.go b/aws/ec2info_test.go index b3490ac5..b2ca9336 100644 --- a/aws/ec2info_test.go +++ b/aws/ec2info_test.go @@ -12,8 +12,8 @@ import ( ) func TestTag_MissingKey(t *testing.T) { - server, ec2meta := MockServer(200, `"i-1234"`) - defer server.Close() + ec2meta := MockEC2Meta(map[string]string{"instance-id": "i-1234"}, nil, "") + client := DummyInstanceDescriber{ tags: []*ec2.Tag{ { @@ -39,8 +39,8 @@ func TestTag_MissingKey(t *testing.T) { } func TestTag_ValidKey(t *testing.T) { - server, ec2meta := MockServer(200, `"i-1234"`) - defer server.Close() + ec2meta := MockEC2Meta(map[string]string{"instance-id": "i-1234"}, nil, "") + client := DummyInstanceDescriber{ tags: []*ec2.Tag{ { @@ -66,8 +66,7 @@ func TestTag_ValidKey(t *testing.T) { } func TestTags(t *testing.T) { - server, ec2meta := MockServer(200, `"i-1234"`) - defer server.Close() + ec2meta := MockEC2Meta(map[string]string{"instance-id": "i-1234"}, nil, "") client := DummyInstanceDescriber{ tags: []*ec2.Tag{ { @@ -92,9 +91,9 @@ func TestTags(t *testing.T) { } func TestTag_NonEC2(t *testing.T) { - server, ec2meta := MockServer(404, "") + ec2meta := MockEC2Meta(nil, nil, "") ec2meta.nonAWS = true - defer server.Close() + client := DummyInstanceDescriber{} e := &Ec2Info{ describer: func() (InstanceDescriber, error) { @@ -109,8 +108,7 @@ func TestTag_NonEC2(t *testing.T) { } func TestNewEc2Info(t *testing.T) { - server, ec2meta := MockServer(200, `"i-1234"`) - defer server.Close() + ec2meta := MockEC2Meta(map[string]string{"instance-id": "i-1234"}, nil, "") client := DummyInstanceDescriber{ tags: []*ec2.Tag{ { @@ -161,8 +159,8 @@ func TestGetRegion(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "unknown", region) - server, ec2meta := MockServer(200, `{"region":"us-east-1"}`) - defer server.Close() + ec2meta := MockEC2Meta(nil, nil, "us-east-1") + region, err = getRegion(ec2meta) assert.NoError(t, err) assert.Equal(t, "us-east-1", region) diff --git a/aws/ec2meta.go b/aws/ec2meta.go index 4165214c..3a15726d 100644 --- a/aws/ec2meta.go +++ b/aws/ec2meta.go @@ -1,20 +1,15 @@ package aws import ( - "encoding/json" - "io/ioutil" "net/http" "strings" - "time" - - "github.com/pkg/errors" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/session" "github.com/hairyhenderson/gomplate/v3/env" ) -// DefaultEndpoint - -var DefaultEndpoint = "http://169.254.169.254" - const ( // the default region unknown = "unknown" @@ -22,20 +17,38 @@ const ( // Ec2Meta - type Ec2Meta struct { - Client *http.Client - cache map[string]string - Endpoint string - options ClientOptions - nonAWS bool + metadataCache map[string]string + dynamicdataCache map[string]string + ec2MetadataProvider func() (EC2Metadata, error) + nonAWS bool +} + +type EC2Metadata interface { + GetMetadata(p string) (string, error) + GetDynamicData(p string) (string, error) + Region() (string, error) } // NewEc2Meta - func NewEc2Meta(options ClientOptions) *Ec2Meta { - if endpoint := env.Getenv("AWS_META_ENDPOINT"); endpoint != "" { - DefaultEndpoint = endpoint + return &Ec2Meta{ + metadataCache: make(map[string]string), + dynamicdataCache: make(map[string]string), + ec2MetadataProvider: func() (EC2Metadata, error) { + config := aws.NewConfig() + config = config.WithHTTPClient(&http.Client{Timeout: options.Timeout}) + if endpoint := env.Getenv("AWS_META_ENDPOINT"); endpoint != "" { + config = config.WithEndpoint(endpoint) + } + + s, err := session.NewSession(config) + if err != nil { + return nil, err + } + + return ec2metadata.New(s), nil + }, } - - return &Ec2Meta{cache: make(map[string]string), options: options} } // returnDefault - @@ -56,10 +69,8 @@ func unreachable(err error) bool { return false } -// retrieve EC2 metadata, defaulting if we're not in EC2 or if there's a non-OK -// response. If there is an OK response, but we can't parse it, this errors -func (e *Ec2Meta) retrieveMetadata(url string, def ...string) (string, error) { - if value, ok := e.cache[url]; ok { +func (e *Ec2Meta) retrieveMetadata(key string, def ...string) (string, error) { + if value, ok := e.metadataCache[key]; ok { return value, nil } @@ -67,55 +78,57 @@ func (e *Ec2Meta) retrieveMetadata(url string, def ...string) (string, error) { return returnDefault(def), nil } - if e.Client == nil { - timeout := e.options.Timeout - if timeout == 0 { - timeout = 500 * time.Millisecond - } - e.Client = &http.Client{Timeout: timeout} + emd, err := e.ec2MetadataProvider() + if err != nil { + return "", err } - resp, err := e.Client.Get(url) + + value, err := emd.GetMetadata(key) if err != nil { if unreachable(err) { e.nonAWS = true } return returnDefault(def), nil } + e.metadataCache[key] = value - // nolint: errcheck - defer resp.Body.Close() - if resp.StatusCode > 399 { + return value, nil +} + +func (e *Ec2Meta) retrieveDynamicdata(key string, def ...string) (string, error) { + if value, ok := e.dynamicdataCache[key]; ok { + return value, nil + } + + if e.nonAWS { return returnDefault(def), nil } - body, err := ioutil.ReadAll(resp.Body) + emd, err := e.ec2MetadataProvider() if err != nil { - return "", errors.Wrapf(err, "Failed to read response body from %s", url) + return "", err } - value := strings.TrimSpace(string(body)) - e.cache[url] = value + + value, err := emd.GetDynamicData(key) + if err != nil { + if unreachable(err) { + e.nonAWS = true + } + return returnDefault(def), nil + } + e.dynamicdataCache[key] = value return value, nil } // Meta - func (e *Ec2Meta) Meta(key string, def ...string) (string, error) { - if e.Endpoint == "" { - e.Endpoint = DefaultEndpoint - } - - url := e.Endpoint + "/latest/meta-data/" + key - return e.retrieveMetadata(url, def...) + return e.retrieveMetadata(key, def...) } // Dynamic - func (e *Ec2Meta) Dynamic(key string, def ...string) (string, error) { - if e.Endpoint == "" { - e.Endpoint = DefaultEndpoint - } - - url := e.Endpoint + "/latest/dynamic/" + key - return e.retrieveMetadata(url, def...) + return e.retrieveDynamicdata(key, def...) } // Region - @@ -125,28 +138,19 @@ func (e *Ec2Meta) Region(def ...string) (string, error) { defaultRegion = unknown } - doc, err := e.Dynamic("instance-identity/document", `{"region":"`+defaultRegion+`"}`) + if e.nonAWS { + return defaultRegion, nil + } + + emd, err := e.ec2MetadataProvider() if err != nil { return "", err } - obj := &InstanceDocument{ - Region: defaultRegion, - } - err = json.Unmarshal([]byte(doc), &obj) - if err != nil { - return "", errors.Wrapf(err, "Unable to unmarshal JSON object %s", doc) + + region, err := emd.Region() + if err != nil || region == "" { + return defaultRegion, nil } - return obj.Region, nil -} -// InstanceDocument - -type InstanceDocument struct { - PrivateIP string `json:"privateIp"` - AvailabilityZone string `json:"availabilityZone"` - InstanceID string `json:"InstanceId"` - InstanceType string `json:"InstanceType"` - AccountID string `json:"AccountId"` - ImageID string `json:"imageId"` - Architecture string `json:"architecture"` - Region string `json:"region"` + return region, nil } diff --git a/aws/ec2meta_test.go b/aws/ec2meta_test.go index 39fe6449..732f97b6 100644 --- a/aws/ec2meta_test.go +++ b/aws/ec2meta_test.go @@ -15,54 +15,47 @@ func must(r interface{}, err error) interface{} { } func TestMeta_MissingKey(t *testing.T) { - server, ec2meta := MockServer(404, "") - defer server.Close() + ec2meta := MockEC2Meta(nil, nil, "") assert.Empty(t, must(ec2meta.Meta("foo"))) assert.Equal(t, "default", must(ec2meta.Meta("foo", "default"))) } func TestMeta_ValidKey(t *testing.T) { - server, ec2meta := MockServer(200, "i-1234") - defer server.Close() + ec2meta := MockEC2Meta(map[string]string{"instance-id": "i-1234"}, nil, "") assert.Equal(t, "i-1234", must(ec2meta.Meta("instance-id"))) assert.Equal(t, "i-1234", must(ec2meta.Meta("instance-id", "unused default"))) } func TestDynamic_MissingKey(t *testing.T) { - server, ec2meta := MockServer(404, "") - defer server.Close() + ec2meta := MockEC2Meta(nil, nil, "") assert.Empty(t, must(ec2meta.Dynamic("foo"))) assert.Equal(t, "default", must(ec2meta.Dynamic("foo", "default"))) } func TestDynamic_ValidKey(t *testing.T) { - server, ec2meta := MockServer(200, "i-1234") - defer server.Close() + ec2meta := MockEC2Meta(nil, map[string]string{"instance-id": "i-1234"}, "") assert.Equal(t, "i-1234", must(ec2meta.Dynamic("instance-id"))) assert.Equal(t, "i-1234", must(ec2meta.Dynamic("instance-id", "unused default"))) } func TestRegion_NoRegion(t *testing.T) { - server, ec2meta := MockServer(200, "{}") - defer server.Close() + ec2meta := MockEC2Meta(nil, nil, "") assert.Equal(t, "unknown", must(ec2meta.Region())) } func TestRegion_NoRegionWithDefault(t *testing.T) { - server, ec2meta := MockServer(200, "{}") - defer server.Close() + ec2meta := MockEC2Meta(nil, nil, "") assert.Equal(t, "foo", must(ec2meta.Region("foo"))) } func TestRegion_KnownRegion(t *testing.T) { - server, ec2meta := MockServer(200, `{"region":"us-east-1"}`) - defer server.Close() + ec2meta := MockEC2Meta(nil, nil, "us-east-1") assert.Equal(t, "us-east-1", must(ec2meta.Region())) } diff --git a/aws/testutils.go b/aws/testutils.go index 9f4c0240..bc57e117 100644 --- a/aws/testutils.go +++ b/aws/testutils.go @@ -2,42 +2,23 @@ package aws import ( "fmt" - "net/http" - "net/http/httptest" - "net/url" "github.com/aws/aws-sdk-go/service/ec2" ) -// MockServer - -func MockServer(code int, body string) (*httptest.Server, *Ec2Meta) { - server, httpClient := MockHTTPServer(code, body) - - client := &Ec2Meta{ - Client: httpClient, - cache: make(map[string]string), - Endpoint: server.URL + "/", - options: ClientOptions{}, - nonAWS: false, - } - return server, client -} - -// MockHTTPServer - -func MockHTTPServer(code int, body string) (*httptest.Server, *http.Client) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(code) - // nolint: errcheck - fmt.Fprintln(w, body) - })) - - tr := &http.Transport{ - Proxy: func(req *http.Request) (*url.URL, error) { - return url.Parse(server.URL) +// MockEC2Meta - +func MockEC2Meta(data map[string]string, dynamicData map[string]string, region string) *Ec2Meta { + return &Ec2Meta{ + metadataCache: map[string]string{}, + dynamicdataCache: map[string]string{}, + ec2MetadataProvider: func() (EC2Metadata, error) { + return &DummEC2MetadataProvider{ + data: data, + dynamicData: dynamicData, + region: region, + }, nil }, } - httpClient := &http.Client{Transport: tr} - return server, httpClient } // NewDummyEc2Info - @@ -45,13 +26,16 @@ func NewDummyEc2Info(metaClient *Ec2Meta) *Ec2Info { i := &Ec2Info{ metaClient: metaClient, describer: func() (InstanceDescriber, error) { return DummyInstanceDescriber{}, nil }, + cache: map[string]interface{}{}, } return i } // NewDummyEc2Meta - func NewDummyEc2Meta() *Ec2Meta { - return &Ec2Meta{nonAWS: true} + return &Ec2Meta{ + nonAWS: true, + } } // DummyInstanceDescriber - test doubles @@ -74,3 +58,29 @@ func (d DummyInstanceDescriber) DescribeInstances(*ec2.DescribeInstancesInput) ( } return output, nil } + +type DummEC2MetadataProvider struct { + data map[string]string + dynamicData map[string]string + region string +} + +func (d DummEC2MetadataProvider) GetMetadata(p string) (string, error) { + v, ok := d.data[p] + if !ok { + return "", fmt.Errorf("cannot find %v", p) + } + return v, nil +} + +func (d DummEC2MetadataProvider) GetDynamicData(p string) (string, error) { + v, ok := d.dynamicData[p] + if !ok { + return "", fmt.Errorf("cannot find %v", p) + } + return v, nil +} + +func (d DummEC2MetadataProvider) Region() (string, error) { + return d.region, nil +} diff --git a/gomplate_test.go b/gomplate_test.go index 0d621ab7..28adbf8d 100644 --- a/gomplate_test.go +++ b/gomplate_test.go @@ -3,7 +3,6 @@ package gomplate import ( "bytes" "context" - "net/http/httptest" "os" "path/filepath" "testing" @@ -53,26 +52,23 @@ func TestBoolTemplates(t *testing.T) { } func TestEc2MetaTemplates(t *testing.T) { - createGomplate := func(status int, body string) (*gomplate, *httptest.Server) { - server, ec2meta := aws.MockServer(status, body) - return &gomplate{funcMap: template.FuncMap{"ec2meta": ec2meta.Meta}}, server + createGomplate := func(data map[string]string, region string) *gomplate { + ec2meta := aws.MockEC2Meta(data, nil, region) + return &gomplate{funcMap: template.FuncMap{"ec2meta": ec2meta.Meta}} } - g, s := createGomplate(404, "") - defer s.Close() + g := createGomplate(nil, "") assert.Equal(t, "", testTemplate(t, g, `{{ec2meta "foo"}}`)) assert.Equal(t, "default", testTemplate(t, g, `{{ec2meta "foo" "default"}}`)) - s.Close() - g, s = createGomplate(200, "i-1234") - defer s.Close() + g = createGomplate(map[string]string{"instance-id": "i-1234"}, "") assert.Equal(t, "i-1234", testTemplate(t, g, `{{ec2meta "instance-id"}}`)) assert.Equal(t, "i-1234", testTemplate(t, g, `{{ec2meta "instance-id" "default"}}`)) } func TestEc2MetaTemplates_WithJSON(t *testing.T) { - server, ec2meta := aws.MockServer(200, `{"foo":"bar"}`) - defer server.Close() + ec2meta := aws.MockEC2Meta(map[string]string{"obj": `"foo": "bar"`}, map[string]string{"obj": `"foo": "baz"`}, "") + g := &gomplate{ funcMap: template.FuncMap{ "ec2meta": ec2meta.Meta, @@ -82,7 +78,7 @@ func TestEc2MetaTemplates_WithJSON(t *testing.T) { } assert.Equal(t, "bar", testTemplate(t, g, `{{ (ec2meta "obj" | json).foo }}`)) - assert.Equal(t, "bar", testTemplate(t, g, `{{ (ec2dynamic "obj" | json).foo }}`)) + assert.Equal(t, "baz", testTemplate(t, g, `{{ (ec2dynamic "obj" | json).foo }}`)) } func TestJSONArrayTemplates(t *testing.T) { -- cgit v1.2.3