summaryrefslogtreecommitdiff
path: root/aws/sts.go
blob: 3883de42d9577fa121888741a95f3601534152ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package aws

import (
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/sts"
)

// STS is a memoizing wrapper around the AWS Security Token Service
// GetCallerIdentity operation. Use NewSTS to construct one.
type STS struct {
	// identifier yields the client that GetCallerIdentity is invoked on.
	identifier func() CallerIdentitifier

	// cache memoizes successful API responses, keyed by operation name
	// (currently only "GetCallerIdentity").
	cache map[string]any
}

// identifierClient is the package-level client shared by every STS value
// created with NewSTS; it stays nil until first use.
// NOTE(review): it is read and written without synchronization.
var (
	identifierClient CallerIdentitifier
)

// CallerIdentitifier is the single STS operation this package depends on.
// The client returned by sts.New satisfies it. The (misspelled) name is kept
// because it is exported.
type CallerIdentitifier interface {
	GetCallerIdentity(input *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error)
}

// NewSTS returns an STS with an empty response cache. The ClientOptions
// argument is currently ignored. The underlying STS client is not created
// here. On first use it is built from SDKSession() and stored in the
// package-level identifierClient, so later STS values reuse it.
func NewSTS(_ ClientOptions) *STS {
	lazyClient := func() CallerIdentitifier {
		if identifierClient != nil {
			return identifierClient
		}
		identifierClient = sts.New(SDKSession())
		return identifierClient
	}

	return &STS{
		cache:      map[string]any{},
		identifier: lazyClient,
	}
}

// getCallerID returns the STS GetCallerIdentity result for the current
// credentials. It memoizes successful responses in s.cache. Errors are
// returned unchanged and are not cached, so a later call retries the request.
func (s *STS) getCallerID() (*sts.GetCallerIdentityOutput, error) {
	const cacheKey = "GetCallerIdentity"

	// Consult the cache before resolving the client: resolving it may build an
	// SDK session, which is wasted work once a result is already cached.
	// Indexing a nil map and asserting on a missing (nil) value are both safe.
	if cached, ok := s.cache[cacheKey].(*sts.GetCallerIdentityOutput); ok {
		return cached, nil
	}

	out, err := s.identifier().GetCallerIdentity(&sts.GetCallerIdentityInput{})
	if err != nil {
		return nil, err
	}

	// Allocate lazily so an STS whose cache was never initialised does not
	// panic on assignment to a nil map.
	if s.cache == nil {
		s.cache = make(map[string]any)
	}
	s.cache[cacheKey] = out
	return out, nil
}

// UserID returns the UserId field reported by STS GetCallerIdentity for the
// current credentials ("" if the field is unset). It returns "" and the
// underlying error if the lookup fails.
func (s *STS) UserID() (string, error) {
	var userID string
	identity, err := s.getCallerID()
	if err == nil {
		userID = aws.StringValue(identity.UserId)
	}
	return userID, err
}

// Account returns the Account field reported by STS GetCallerIdentity for the
// current credentials ("" if the field is unset). It returns "" and the
// underlying error if the lookup fails.
func (s *STS) Account() (string, error) {
	var account string
	identity, err := s.getCallerID()
	if err == nil {
		account = aws.StringValue(identity.Account)
	}
	return account, err
}

// Arn returns the Arn field reported by STS GetCallerIdentity for the current
// credentials ("" if the field is unset). It returns "" and the underlying
// error if the lookup fails.
func (s *STS) Arn() (string, error) {
	var arn string
	identity, err := s.getCallerID()
	if err == nil {
		arn = aws.StringValue(identity.Arn)
	}
	return arn, err
}