diff options
| author | Ishita Sequeira <46771830+ishitasequeira@users.noreply.github.com> | 2024-11-13 08:57:55 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-13 08:57:55 -0500 |
| commit | cde58c56ff0981df884433eeea76d747f4b809de (patch) | |
| tree | 6fb745e6995e3bc3fe87be90524a4d39ebe3531a | |
| parent | 1d04e9aa3e237530336bc0693eaaa7f6de7b66e1 (diff) | |
refactor: Add registry pkg registry scanner (#933)
Signed-off-by: Ishita Sequeira <ishiseq29@gmail.com>
24 files changed, 3521 insertions, 4 deletions
diff --git a/registry-scanner/go.mod b/registry-scanner/go.mod index 3e0dd95..00a1589 100644 --- a/registry-scanner/go.mod +++ b/registry-scanner/go.mod @@ -3,15 +3,20 @@ module github.com/argoproj-labs/argocd-image-updater/registry-scanner go 1.22.3 require ( - github.com/argoproj-labs/argocd-image-updater v0.14.0 + github.com/distribution/distribution/v3 v3.0.0-20230722181636-7b502560cad4 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 + k8s.io/api v0.31.2 + k8s.io/apimachinery v0.31.2 + k8s.io/client-go v0.31.2 + sigs.k8s.io/kustomize/api v0.12.1 + sigs.k8s.io/kustomize/kyaml v0.13.9 ) require ( github.com/Masterminds/semver/v3 v3.2.1 - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/sys v0.20.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + golang.org/x/sys v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/registry-scanner/go.sum b/registry-scanner/go.sum index 16441ee..0eff022 100644 --- a/registry-scanner/go.sum +++ b/registry-scanner/go.sum @@ -5,6 +5,8 @@ github.com/argoproj-labs/argocd-image-updater v0.14.0/go.mod h1:PSVBweUoS6ogVFAi github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/distribution/distribution/v3 v3.0.0-20230722181636-7b502560cad4/go.mod h1:+fqBJ4vPYo4Uu1ZE4d+bUtTLRXfdSL3NvCZIZ9GHv58= github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -19,6 +21,7 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -30,6 +33,7 @@ golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= @@ -43,13 +47,22 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/api v0.26.11 h1:hLhTZRdYc3vBBOY4wbEyTLWgMyieOAk2Ws9NG57QqO4= k8s.io/api v0.26.11/go.mod h1:bSr/A0TKRt5W2OMDdexkM/ER1NxOxiQqNNFXW2nMZrM= +k8s.io/api v0.31.2 h1:3wLBbL5Uom/8Zy98GRPXpJ254nEFpl+hwndmk9RwmL0= +k8s.io/api v0.31.2/go.mod h1:bWmGvrGPssSK1ljmLzd3pwCQ9MgoTsRCuK35u6SygUk= k8s.io/apimachinery v0.26.11 h1:w//840HHdwSRKqD15j9YX9HLlU6RPlfrvW0xEhLk2+0= k8s.io/apimachinery v0.26.11/go.mod h1:2/HZp0l6coXtS26du1Bk36fCuAEr/lVs9Q9NbpBtd1Y= +k8s.io/apimachinery v0.31.2/go.mod h1:rsPdaZJfTfLsNJSQzNHQvYoTmxhoOEofxtOsF3rtsMo= +k8s.io/client-go v0.31.2 h1:Y2F4dxU5d3AQj+ybwSMqQnpZH9F30//1ObxOKlTI9yc= +k8s.io/client-go v0.31.2/go.mod h1:NPa74jSVR/+eez2dFsEIHNa+3o09vtNaWwWwb1qSxSs= +k8s.io/client-go v1.5.2 h1:JOxmv4FxrCIOS54kAABbN8/hA9jqGpns+Zc6soNgd8U= +k8s.io/client-go v1.5.2/go.mod h1:OmM68YRko3DQ0sjlnWxzjQF9lcSLHJXuGMTo23rc7wI= k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg= k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/utils v0.0.0-20230220204549-a5ecb0141aa5 h1:kmDqav+P+/5e1i9tFfHq1qcF3sOrDp+YEkVDAHu7Jwk= k8s.io/utils v0.0.0-20230220204549-a5ecb0141aa5/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= +sigs.k8s.io/kustomize/api v0.12.1/go.mod h1:y3JUhimkZkR6sbLNwfJHxvo1TCLwuwm14sCYnkH6S1s= +sigs.k8s.io/kustomize/kyaml v0.13.9/go.mod h1:QsRbD0/KcU+wdk0/L0fIp2KLnohkVzs6fQ85/nOXac4= sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08= diff --git a/registry-scanner/pkg/cache/cache.go b/registry-scanner/pkg/cache/cache.go new file mode 100644 index 0000000..1ceef66 --- /dev/null +++ b/registry-scanner/pkg/cache/cache.go @@ -0,0 +1,20 @@ +package cache + +import ( + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/tag" +) + +type ImageTagCache interface { + HasTag(imageName string, imageTag string) bool + GetTag(imageName string, imageTag string) (*tag.ImageTag, error) + SetTag(imageName string, imgTag *tag.ImageTag) + ClearCache() + NumEntries() int +} + +// KnownImage represents a known image and the applications using it, without +// any version/tag information. +type KnownImage struct { + ImageName string + // Applications []string +} diff --git a/registry-scanner/pkg/cache/memcache.go b/registry-scanner/pkg/cache/memcache.go new file mode 100644 index 0000000..03b4080 --- /dev/null +++ b/registry-scanner/pkg/cache/memcache.go @@ -0,0 +1,74 @@ +package cache + +import ( + "fmt" + + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/tag" + + memcache "github.com/patrickmn/go-cache" +) + +type MemCache struct { + cache *memcache.Cache +} + +// NewMemCache returns a new instance of MemCache +func NewMemCache() ImageTagCache { + mc := MemCache{} + c := memcache.New(0, 0) + mc.cache = c + return &mc +} + +// HasTag returns true if cache has entry for given tag, false if not +func (mc *MemCache) HasTag(imageName string, tagName string) bool { + tag, err := mc.GetTag(imageName, tagName) + if err != nil || tag == nil { + return false + } else { + return true + } +} + +// SetTag sets a tag entry into the cache +func (mc *MemCache) SetTag(imageName string, imgTag *tag.ImageTag) { + mc.cache.Set(tagCacheKey(imageName, imgTag.TagName), *imgTag, -1) +} + +// GetTag gets a tag entry from the cache +func (mc *MemCache) GetTag(imageName string, tagName string) (*tag.ImageTag, error) { + var imgTag tag.ImageTag + e, ok := mc.cache.Get(tagCacheKey(imageName, tagName)) + if !ok { + return nil, nil + } + imgTag, ok = e.(tag.ImageTag) + if !ok { + return nil, fmt.Errorf("") + } + return &imgTag, nil +} + +func (mc *MemCache) SetImage(imageName, application string) { + mc.cache.Set(imageCacheKey(imageName), application, -1) +} + +// ClearCache clears the cache +func (mc *MemCache) ClearCache() { + for k := range mc.cache.Items() { + mc.cache.Delete(k) + } +} + +// NumEntries returns the number of entries in the cache +func (mc *MemCache) NumEntries() int { + return mc.cache.ItemCount() +} + +func tagCacheKey(imageName, imageTag string) string { + return fmt.Sprintf("tags:%s:%s", imageName, imageTag) +} + +func imageCacheKey(imageName string) string { + return fmt.Sprintf("image:%s", imageName) +} diff --git a/registry-scanner/pkg/cache/memcache_test.go b/registry-scanner/pkg/cache/memcache_test.go new file mode 100644 index 0000000..8fcb47b --- /dev/null +++ b/registry-scanner/pkg/cache/memcache_test.go @@ -0,0 +1,70 @@ +package cache + +import ( + "testing" + "time" + + memcache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/argoproj-labs/argocd-image-updater/pkg/tag" +) + +func Test_MemCache(t *testing.T) { + imageName := "foo/bar" + imageTag := "v1.0.0" + t.Run("Cache hit", func(t *testing.T) { + mc := NewMemCache() + newTag := tag.NewImageTag(imageTag, time.Unix(0, 0), "") + mc.SetTag(imageName, newTag) + cachedTag, err := mc.GetTag(imageName, imageTag) + require.NoError(t, err) + require.NotNil(t, cachedTag) + assert.Equal(t, imageTag, cachedTag.TagName) + assert.True(t, mc.HasTag(imageName, imageTag)) + assert.Equal(t, 1, mc.NumEntries()) + }) + + t.Run("Cache miss", func(t *testing.T) { + mc := NewMemCache() + newTag := tag.NewImageTag(imageTag, time.Unix(0, 0), "") + mc.SetTag(imageName, newTag) + assert.Equal(t, 1, mc.NumEntries()) + cachedTag, err := mc.GetTag(imageName, "v1.0.1") + require.NoError(t, err) + require.Nil(t, cachedTag) + assert.False(t, mc.HasTag(imageName, "v1.0.1")) + }) + + t.Run("Cache clear", func(t *testing.T) { + mc := NewMemCache() + newTag := tag.NewImageTag(imageTag, time.Unix(0, 0), "") + mc.SetTag(imageName, newTag) + cachedTag, err := mc.GetTag(imageName, imageTag) + require.NoError(t, err) + require.NotNil(t, cachedTag) + assert.Equal(t, imageTag, cachedTag.TagName) + assert.True(t, mc.HasTag(imageName, imageTag)) + assert.Equal(t, 1, mc.NumEntries()) + mc.ClearCache() + assert.Equal(t, 0, mc.NumEntries()) + cachedTag, err = mc.GetTag(imageName, imageTag) + require.NoError(t, err) + require.Nil(t, cachedTag) + }) + t.Run("Image Cache Key", func(t *testing.T) { + mc := MemCache{ + cache: memcache.New(0, 0), + } + application := "application1" + key := imageCacheKey(imageName) + mc.SetImage(imageName, application) + app, b := mc.cache.Get(key) + assert.True(t, b) + assert.Equal(t, application, app) + assert.Equal(t, 1, mc.NumEntries()) + mc.ClearCache() + assert.Equal(t, 0, mc.NumEntries()) + }) +} diff --git a/registry-scanner/pkg/kube/kubernetes.go b/registry-scanner/pkg/kube/kubernetes.go new file mode 100644 index 0000000..6771440 --- /dev/null +++ b/registry-scanner/pkg/kube/kubernetes.go @@ -0,0 +1,85 @@ +package kube + +// Kubernetes client related code + +import ( + "context" + "fmt" + "os" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + _ "k8s.io/client-go/plugin/pkg/client/auth" + "k8s.io/client-go/tools/clientcmd" +) + +type KubernetesClient struct { + Clientset kubernetes.Interface + Context context.Context + Namespace string +} + +func NewKubernetesClient(ctx context.Context, client kubernetes.Interface, namespace string) *KubernetesClient { + kc := &KubernetesClient{} + kc.Context = ctx + kc.Clientset = client + kc.Namespace = namespace + return kc +} + +// NewKubernetesClient creates a new Kubernetes client object from given +// configuration file. If configuration file is the empty string, in-cluster +// client will be created. +func NewKubernetesClientFromConfig(ctx context.Context, namespace string, kubeconfig string) (*KubernetesClient, error) { + loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() + loadingRules.DefaultClientConfig = &clientcmd.DefaultClientConfig + loadingRules.ExplicitPath = kubeconfig + overrides := clientcmd.ConfigOverrides{} + clientConfig := clientcmd.NewInteractiveDeferredLoadingClientConfig(loadingRules, &overrides, os.Stdin) + + config, err := clientConfig.ClientConfig() + if err != nil { + return nil, err + } + + if namespace == "" { + namespace, _, err = clientConfig.Namespace() + if err != nil { + return nil, err + } + } + + clientset, err := kubernetes.NewForConfig(config) + if err != nil { + return nil, err + } + + applicationsClientset, err := versioned.NewForConfig(config) + if err != nil { + return nil, err + } + + return NewKubernetesClient(ctx, clientset, applicationsClientset, namespace), nil +} + +// GetSecretData returns the raw data from named K8s secret in given namespace +func (client *KubernetesClient) GetSecretData(namespace string, secretName string) (map[string][]byte, error) { + secret, err := client.Clientset.CoreV1().Secrets(namespace).Get(client.Context, secretName, metav1.GetOptions{}) + if err != nil { + return nil, err + } + return secret.Data, nil +} + +// GetSecretField returns the value of a field from named K8s secret in given namespace +func (client *KubernetesClient) GetSecretField(namespace string, secretName string, field string) (string, error) { + secret, err := client.GetSecretData(namespace, secretName) + if err != nil { + return "", err + } + if data, ok := secret[field]; !ok { + return "", fmt.Errorf("secret '%s/%s' does not have a field '%s'", namespace, secretName, field) + } else { + return string(data), nil + } +} diff --git a/registry-scanner/pkg/kube/kubernetes_test.go b/registry-scanner/pkg/kube/kubernetes_test.go new file mode 100644 index 0000000..c87f94f --- /dev/null +++ b/registry-scanner/pkg/kube/kubernetes_test.go @@ -0,0 +1,68 @@ +package kube + +import ( + "context" + "testing" + + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/test/fake" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/test/fixture" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewKubernetesClient(t *testing.T) { + t.Run("Get new K8s client for remote cluster instance", func(t *testing.T) { + client, err := NewKubernetesClientFromConfig(context.TODO(), "", "../../test/testdata/kubernetes/config") + require.NoError(t, err) + assert.NotNil(t, client) + assert.Equal(t, "default", client.Namespace) + }) + + t.Run("Get new K8s client for remote cluster instance specified namespace", func(t *testing.T) { + client, err := NewKubernetesClientFromConfig(context.TODO(), "argocd", "../../test/testdata/kubernetes/config") + require.NoError(t, err) + assert.NotNil(t, client) + assert.Equal(t, "argocd", client.Namespace) + }) +} + +func Test_GetDataFromSecrets(t *testing.T) { + t.Run("Get all data from dummy secret", func(t *testing.T) { + secret := fixture.MustCreateSecretFromFile("../../test/testdata/resources/dummy-secret.json") + clientset := fake.NewFakeClientsetWithResources(secret) + client := &KubernetesClient{Clientset: clientset} + data, err := client.GetSecretData("test-namespace", "test-secret") + require.NoError(t, err) + require.NotNil(t, data) + assert.Len(t, data, 1) + assert.Equal(t, "argocd", string(data["namespace"])) + }) + + t.Run("Get string data from dummy secret existing field", func(t *testing.T) { + secret := fixture.MustCreateSecretFromFile("../../test/testdata/resources/dummy-secret.json") + clientset := fake.NewFakeClientsetWithResources(secret) + client := &KubernetesClient{Clientset: clientset} + data, err := client.GetSecretField("test-namespace", "test-secret", "namespace") + require.NoError(t, err) + assert.Equal(t, "argocd", data) + }) + + t.Run("Get string data from dummy secret non-existing field", func(t *testing.T) { + secret := fixture.MustCreateSecretFromFile("../../test/testdata/resources/dummy-secret.json") + clientset := fake.NewFakeClientsetWithResources(secret) + client := &KubernetesClient{Clientset: clientset} + data, err := client.GetSecretField("test-namespace", "test-secret", "nonexisting") + require.Error(t, err) + require.Empty(t, data) + }) + + t.Run("Get string data from non-existing secret non-existing field", func(t *testing.T) { + secret := fixture.MustCreateSecretFromFile("../../test/testdata/resources/dummy-secret.json") + clientset := fake.NewFakeClientsetWithResources(secret) + client := &KubernetesClient{Clientset: clientset} + data, err := client.GetSecretField("test-namespace", "test", "namespace") + require.Error(t, err) + require.Empty(t, data) + }) +} diff --git a/registry-scanner/pkg/registry/client.go b/registry-scanner/pkg/registry/client.go new file mode 100644 index 0000000..7a26946 --- /dev/null +++ b/registry-scanner/pkg/registry/client.go @@ -0,0 +1,447 @@ +package registry + +import ( + "context" + "crypto/sha256" + "fmt" + "time" + + "github.com/argoproj/pkg/json" + + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/log" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/options" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/tag" + + "github.com/distribution/distribution/v3" + "github.com/distribution/distribution/v3/manifest/manifestlist" + "github.com/distribution/distribution/v3/manifest/ocischema" + "github.com/distribution/distribution/v3/manifest/schema1" //nolint:staticcheck + "github.com/distribution/distribution/v3/manifest/schema2" + "github.com/distribution/distribution/v3/reference" + "github.com/distribution/distribution/v3/registry/client" + "github.com/distribution/distribution/v3/registry/client/auth" + "github.com/distribution/distribution/v3/registry/client/auth/challenge" + "github.com/distribution/distribution/v3/registry/client/transport" + + "github.com/opencontainers/go-digest" + ociv1 "github.com/opencontainers/image-spec/specs-go/v1" + + "go.uber.org/ratelimit" + + "net/http" + "net/url" + "strings" +) + +// TODO: Check image's architecture and OS + +// knownMediaTypes is the list of media types we can process +var knownMediaTypes = []string{ + ocischema.SchemaVersion.MediaType, + schema1.MediaTypeSignedManifest, //nolint:staticcheck + schema2.SchemaVersion.MediaType, + manifestlist.SchemaVersion.MediaType, + ociv1.MediaTypeImageIndex, +} + +// RegistryClient defines the methods we need for querying container registries +type RegistryClient interface { + NewRepository(nameInRepository string) error + Tags() ([]string, error) + ManifestForTag(tagStr string) (distribution.Manifest, error) + ManifestForDigest(dgst digest.Digest) (distribution.Manifest, error) + TagMetadata(manifest distribution.Manifest, opts *options.ManifestOptions) (*tag.TagInfo, error) +} + +type NewRegistryClient func(*RegistryEndpoint, string, string) (RegistryClient, error) + +// Helper type for registry clients +type registryClient struct { + regClient distribution.Repository + endpoint *RegistryEndpoint + creds credentials +} + +// credentials is an implementation of distribution/V3/session struct +// to manage registry credentials and token +type credentials struct { + username string + password string + refreshTokens map[string]string +} + +func (c credentials) Basic(url *url.URL) (string, string) { + return c.username, c.password +} + +func (c credentials) RefreshToken(url *url.URL, service string) string { + return c.refreshTokens[service] +} + +func (c credentials) SetRefreshToken(realm *url.URL, service, token string) { + if c.refreshTokens != nil { + c.refreshTokens[service] = token + } +} + +// rateLimitTransport encapsulates our custom HTTP round tripper with rate +// limiter from the endpoint. +type rateLimitTransport struct { + limiter ratelimit.Limiter + transport http.RoundTripper + endpoint *RegistryEndpoint +} + +// RoundTrip is a custom RoundTrip method with rate-limiter +func (rlt *rateLimitTransport) RoundTrip(r *http.Request) (*http.Response, error) { + rlt.limiter.Take() + log.Tracef("Performing HTTP %s %s", r.Method, r.URL) + resp, err := rlt.transport.RoundTrip(r) + return resp, err +} + +// NewRepository is a wrapper for creating a registry client that is possibly +// rate-limited by using a custom HTTP round tripper method. +func (clt *registryClient) NewRepository(nameInRepository string) error { + urlToCall := strings.TrimSuffix(clt.endpoint.RegistryAPI, "/") + challengeManager1 := challenge.NewSimpleManager() + _, err := ping(challengeManager1, clt.endpoint, "") + if err != nil { + return err + } + + authTransport := transport.NewTransport( + clt.endpoint.GetTransport(), auth.NewAuthorizer( + challengeManager1, + auth.NewTokenHandler(clt.endpoint.GetTransport(), clt.creds, nameInRepository, "pull"), + auth.NewBasicHandler(clt.creds))) + + rlt := &rateLimitTransport{ + limiter: clt.endpoint.Limiter, + transport: authTransport, + endpoint: clt.endpoint, + } + + named, err := reference.WithName(nameInRepository) + if err != nil { + return err + } + clt.regClient, err = client.NewRepository(named, urlToCall, rlt) + if err != nil { + return err + } + return nil +} + +// NewClient returns a new RegistryClient for the given endpoint information +func NewClient(endpoint *RegistryEndpoint, username, password string) (RegistryClient, error) { + if username == "" && endpoint.Username != "" { + username = endpoint.Username + } + if password == "" && endpoint.Password != "" { + password = endpoint.Password + } + creds := credentials{ + username: username, + password: password, + } + return ®istryClient{ + creds: creds, + endpoint: endpoint, + }, nil +} + +// Tags returns a list of tags for given name in repository +func (clt *registryClient) Tags() ([]string, error) { + tagService := clt.regClient.Tags(context.Background()) + tTags, err := tagService.All(context.Background()) + if err != nil { + return nil, err + } + return tTags, nil +} + +// Manifest returns a Manifest for a given tag in repository +func (clt *registryClient) ManifestForTag(tagStr string) (distribution.Manifest, error) { + manService, err := clt.regClient.Manifests(context.Background()) + if err != nil { + return nil, err + } + manifest, err := manService.Get( + context.Background(), + digest.FromString(tagStr), + distribution.WithTag(tagStr), distribution.WithManifestMediaTypes(knownMediaTypes)) + if err != nil { + return nil, err + } + return manifest, nil +} + +// ManifestForDigest returns a Manifest for a given digest in repository +func (clt *registryClient) ManifestForDigest(dgst digest.Digest) (distribution.Manifest, error) { + manService, err := clt.regClient.Manifests(context.Background()) + if err != nil { + return nil, err + } + manifest, err := manService.Get( + context.Background(), + dgst, + distribution.WithManifestMediaTypes(knownMediaTypes)) + if err != nil { + return nil, err + } + return manifest, nil +} + +// TagMetadata retrieves metadata for a given manifest of given repository +func (client *registryClient) TagMetadata(manifest distribution.Manifest, opts *options.ManifestOptions) (*tag.TagInfo, error) { + ti := &tag.TagInfo{} + logCtx := opts.Logger() + var info struct { + Arch string `json:"architecture"` + Created string `json:"created"` + OS string `json:"os"` + Variant string `json:"variant"` + } + + // We support the following types of manifests as returned by the registry: + // + // V1 (legacy, might go away), V2 and OCI + // + // Also ManifestLists (e.g. on multi-arch images) are supported. + // + switch deserialized := manifest.(type) { + + case *schema1.SignedManifest: //nolint:staticcheck + var man schema1.Manifest = deserialized.Manifest //nolint:staticcheck + if len(man.History) == 0 { + return nil, fmt.Errorf("no history information found in schema V1") + } + + _, mBytes, err := manifest.Payload() + if err != nil { + return nil, err + } + ti.Digest = sha256.Sum256(mBytes) + + logCtx.Tracef("v1 SHA digest is %s", ti.EncodedDigest()) + if err := json.Unmarshal([]byte(man.History[0].V1Compatibility), &info); err != nil { + return nil, err + } + if !opts.WantsPlatform(info.OS, info.Arch, "") { + logCtx.Debugf("ignoring v1 manifest %v. Manifest platform: %s, requested: %s", + ti.EncodedDigest(), options.PlatformKey(info.OS, info.Arch, info.Variant), strings.Join(opts.Platforms(), ",")) + return nil, nil + } + if createdAt, err := time.Parse(time.RFC3339Nano, info.Created); err != nil { + return nil, err + } else { + ti.CreatedAt = createdAt + } + return ti, nil + + case *manifestlist.DeserializedManifestList: + var list manifestlist.DeserializedManifestList = *deserialized + + // List must contain at least one image manifest + if len(list.Manifests) == 0 { + return nil, fmt.Errorf("empty manifestlist not supported") + } + + // We use the SHA from the manifest list to let the container engine + // decide which image to pull, in case of multi-arch clusters. + _, mBytes, err := list.Payload() + if err != nil { + return nil, fmt.Errorf("could not retrieve manifestlist payload: %v", err) + } + ti.Digest = sha256.Sum256(mBytes) + + logCtx.Tracef("SHA256 of manifest parent is %v", ti.EncodedDigest()) + + return TagInfoFromReferences(client, opts, logCtx, ti, list.References()) + + case *ocischema.DeserializedImageIndex: + var index ocischema.DeserializedImageIndex = *deserialized + + // Index must contain at least one image manifest + if len(index.Manifests) == 0 { + return nil, fmt.Errorf("empty index not supported") + } + + // We use the SHA from the manifest index to let the container engine + // decide which image to pull, in case of multi-arch clusters. + _, mBytes, err := index.Payload() + if err != nil { + return nil, fmt.Errorf("could not retrieve index payload: %v", err) + } + ti.Digest = sha256.Sum256(mBytes) + + logCtx.Tracef("SHA256 of manifest parent is %v", ti.EncodedDigest()) + + return TagInfoFromReferences(client, opts, logCtx, ti, index.References()) + + case *schema2.DeserializedManifest: + var man schema2.Manifest = deserialized.Manifest + + logCtx.Tracef("Manifest digest is %v", man.Config.Digest.Encoded()) + + _, mBytes, err := manifest.Payload() + if err != nil { + return nil, err + } + ti.Digest = sha256.Sum256(mBytes) + logCtx.Tracef("v2 SHA digest is %s", ti.EncodedDigest()) + + // The data we require from a V2 manifest is in a blob that we need to + // fetch from the registry. + blobReader, err := client.regClient.Blobs(context.Background()).Get(context.Background(), man.Config.Digest) + if err != nil { + return nil, err + } + + if err := json.Unmarshal(blobReader, &info); err != nil { + return nil, err + } + + if !opts.WantsPlatform(info.OS, info.Arch, info.Variant) { + logCtx.Debugf("ignoring v2 manifest %v. Manifest platform: %s, requested: %s", + ti.EncodedDigest(), options.PlatformKey(info.OS, info.Arch, info.Variant), strings.Join(opts.Platforms(), ",")) + return nil, nil + } + + if ti.CreatedAt, err = time.Parse(time.RFC3339Nano, info.Created); err != nil { + return nil, err + } + + return ti, nil + case *ocischema.DeserializedManifest: + var man ocischema.Manifest = deserialized.Manifest + + _, mBytes, err := manifest.Payload() + if err != nil { + return nil, err + } + ti.Digest = sha256.Sum256(mBytes) + logCtx.Tracef("OCI SHA digest is %s", ti.EncodedDigest()) + + // The data we require from a V2 manifest is in a blob that we need to + // fetch from the registry. + blobReader, err := client.regClient.Blobs(context.Background()).Get(context.Background(), man.Config.Digest) + if err != nil { + return nil, err + } + + if err := json.Unmarshal(blobReader, &info); err != nil { + return nil, err + } + + if !opts.WantsPlatform(info.OS, info.Arch, info.Variant) { + logCtx.Debugf("ignoring OCI manifest %v. Manifest platform: %s, requested: %s", + ti.EncodedDigest(), options.PlatformKey(info.OS, info.Arch, info.Variant), strings.Join(opts.Platforms(), ",")) + return nil, nil + } + + if ti.CreatedAt, err = time.Parse(time.RFC3339Nano, info.Created); err != nil { + return nil, err + } + + return ti, nil + default: + return nil, fmt.Errorf("invalid manifest type %T", manifest) + } +} + +// TagInfoFromReferences is a helper method to retrieve metadata for a given +// list of references. It will return the most recent pushed manifest from the +// list of references. +func TagInfoFromReferences(client *registryClient, opts *options.ManifestOptions, logCtx *log.LogContext, ti *tag.TagInfo, references []distribution.Descriptor) (*tag.TagInfo, error) { + var ml []distribution.Descriptor + platforms := []string{} + + for _, ref := range references { + var refOS, refArch, refVariant string + if ref.Platform != nil { + refOS = ref.Platform.OS + refArch = ref.Platform.Architecture + refVariant = ref.Platform.Variant + } + platform1 := options.PlatformKey(refOS, refArch, refVariant) + platforms = append(platforms, platform1) + logCtx.Tracef("Found %s", platform1) + if !opts.WantsPlatform(refOS, refArch, refVariant) { + logCtx.Tracef("Ignoring referenced manifest %v because platform %s does not match any of: %s", + ref.Digest, + platform1, + strings.Join(opts.Platforms(), ",")) + continue + } + ml = append(ml, ref) + } + + // We need at least one reference that matches requested platforms + if len(ml) == 0 { + logCtx.Debugf("Manifest list did not contain any usable reference. Platforms requested: (%s), platforms included: (%s)", + strings.Join(opts.Platforms(), ","), strings.Join(platforms, ",")) + return nil, nil + } + + // For some strategies, we do not need to fetch metadata for further + // processing. + if !opts.WantsMetadata() { + return ti, nil + } + + // Loop through all referenced manifests to get their metadata. We only + // consider manifests for platforms we are interested in. + for _, ref := range ml { + logCtx.Tracef("Inspecting metadata of reference: %v", ref.Digest) + + man, err := client.ManifestForDigest(ref.Digest) + if err != nil { + return nil, fmt.Errorf("could not fetch manifest %v: %v", ref.Digest, err) + } + + cti, err := client.TagMetadata(man, opts) + if err != nil { + return nil, fmt.Errorf("could not fetch metadata for manifest %v: %v", ref.Digest, err) + } + + // We save the timestamp of the most recent pushed manifest for any + // given reference, if the metadata for the tag was correctly + // retrieved. This is important for the latest update strategy to + // be able to handle multi-arch images. The latest strategy will + // consider the most recent reference from an image index. + if cti != nil { + if cti.CreatedAt.After(ti.CreatedAt) { + ti.CreatedAt = cti.CreatedAt + } + } else { + logCtx.Warnf("returned metadata for manifest %v is nil, this should not happen.", ref.Digest) + continue + } + } + + return ti, nil +} + +// Implementation of ping method to initialize the challenge list +// Without this, tokenHandler and AuthorizationHandler won't work +func ping(manager challenge.Manager, endpoint *RegistryEndpoint, versionHeader string) ([]auth.APIVersion, error) { + httpc := &http.Client{Transport: endpoint.GetTransport()} + url := endpoint.RegistryAPI + "/v2/" + resp, err := httpc.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + // Let's consider only HTTP 200 and 401 valid responses for the initial request + if resp.StatusCode != 200 && resp.StatusCode != 401 { + return nil, fmt.Errorf("endpoint %s does not seem to be a valid v2 Docker Registry API (received HTTP code %d for GET %s)", endpoint.RegistryAPI, resp.StatusCode, url) + } + + if err := manager.AddResponse(resp); err != nil { + return nil, err + } + + return auth.APIVersions(resp, versionHeader), err +} diff --git a/registry-scanner/pkg/registry/client_test.go b/registry-scanner/pkg/registry/client_test.go new file mode 100644 index 0000000..074a3cd --- /dev/null +++ b/registry-scanner/pkg/registry/client_test.go @@ -0,0 +1,609 @@ +package registry + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/distribution/distribution/v3/manifest" + "github.com/distribution/distribution/v3/manifest/manifestlist" + "github.com/distribution/distribution/v3/manifest/ocischema" + "github.com/distribution/distribution/v3/manifest/schema2" + + "github.com/distribution/distribution/v3" + "github.com/distribution/distribution/v3/manifest/schema1" //nolint:staticcheck + v1 "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/log" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/options" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/registry/mocks" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/tag" +) + +func TestBasic(t *testing.T) { + creds := credentials{ + username: "testuser", + password: "testpass", + } + + testURL, _ := url.Parse("https://example.com") + username, password := creds.Basic(testURL) + + if username != "testuser" { + t.Errorf("Expected username to be 'testuser', got '%s'", username) + } + if password != "testpass" { + t.Errorf("Expected password to be 'testpass', got '%s'", password) + } +} + +func TestNewRepository(t *testing.T) { + t.Run("Invalid Reference Format", func(t *testing.T) { + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + require.NoError(t, err) + err = client.NewRepository("test@test") + require.Error(t, err) + assert.Contains(t, "invalid reference format", err.Error()) + + }) + t.Run("Success Ping", func(t *testing.T) { + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + require.NoError(t, err) + err = client.NewRepository("test/test") + require.NoError(t, err) + }) + + t.Run("Fail Ping", func(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + ep := &RegistryEndpoint{RegistryAPI: testServer.URL} + client, err := NewClient(ep, "", "") + require.NoError(t, err) + err = client.NewRepository("") + require.Error(t, err) + }) + +} + +func TestRoundTrip_Success(t *testing.T) { + // Create mocks + mockLimiter := new(mocks.Limiter) + mockTransport := new(mocks.RoundTripper) + endpoint := &RegistryEndpoint{RegistryAPI: "http://example.com"} + // Create an instance of rateLimitTransport with mocks + rlt := &rateLimitTransport{ + limiter: mockLimiter, + transport: mockTransport, + endpoint: endpoint, + } + // Create a sample HTTP request + req, err := http.NewRequest("GET", "http://example.com", nil) + assert.NoError(t, err) + resp := &http.Response{StatusCode: http.StatusOK} + // Set up expectations + mockLimiter.On("Take").Return(time.Now()) + mockTransport.On("RoundTrip", req).Return(resp, nil) + // Call the method under test + actualResp, err := rlt.RoundTrip(req) + // Assert the expectations + mockLimiter.AssertExpectations(t) + mockTransport.AssertExpectations(t) + assert.NoError(t, err) + assert.Equal(t, resp, actualResp) +} +func TestRoundTrip_Failure(t *testing.T) { + // Create mocks + mockLimiter := new(mocks.Limiter) + mockTransport := new(mocks.RoundTripper) + endpoint := &RegistryEndpoint{RegistryAPI: "http://example.com"} + // Create an instance of rateLimitTransport with mocks + rlt := &rateLimitTransport{ + limiter: mockLimiter, + transport: mockTransport, + endpoint: endpoint, + } + // Create a sample HTTP request + req := httptest.NewRequest("GET", "http://example.com", nil) + // Set up expectations + mockLimiter.On("Take").Return(time.Now()) + mockTransport.On("RoundTrip", req).Return(nil, errors.New("Error on caling func RoundTrip")) + // Call the method under test + actualResp, err := rlt.RoundTrip(req) + // Assert the expectations + mockLimiter.AssertExpectations(t) + mockTransport.AssertExpectations(t) + assert.Error(t, err) + assert.Nil(t, actualResp) +} + +func TestRefreshToken(t *testing.T) { + creds := credentials{ + refreshTokens: map[string]string{ + "service1": "token1", + }, + } + testURL, _ := url.Parse("https://example.com") + token := creds.RefreshToken(testURL, "service1") + if token != "token1" { + t.Errorf("Expected token to be 'token1', got '%s'", token) + } +} + +func TestSetRefreshToken(t *testing.T) { + creds := credentials{ + refreshTokens: make(map[string]string), + } + testURL, _ := url.Parse("https://example.com") + creds.SetRefreshToken(testURL, "service1", "token1") + + if token, exists := creds.refreshTokens["service1"]; !exists { + t.Error("Expected token for 'service1' to exist") + } else if token != "token1" { + t.Errorf("Expected token to be 'token1', got '%s'", token) + } +} +func TestNewClient(t *testing.T) { + t.Run("Create client with provided username and password", func(t *testing.T) { + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + _, err = NewClient(ep, "testuser", "pass") + require.NoError(t, err) + }) + t.Run("Create client with empty username and password", func(t *testing.T) { + ep := &RegistryEndpoint{Username: "testuser", Password: "pass"} + _, err := NewClient(ep, "", "") + require.NoError(t, err) + }) +} + +func TestTags(t *testing.T) { + t.Run("success", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + mockTagService := new(mocks.TagService) + mockTagService.On("All", mock.Anything).Return([]string{"testTag-1", "testTag-2"}, nil) + mockRegClient.On("Tags", mock.Anything).Return(mockTagService) + tags, err := client.Tags() + require.NoError(t, err) + assert.Contains(t, tags, "testTag-1") + assert.Contains(t, tags, "testTag-2") + }) + t.Run("Fail", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + mockTagService := new(mocks.TagService) + mockTagService.On("All", mock.Anything).Return([]string{}, errors.New("Error on caling func All")) + mockRegClient.On("Tags", mock.Anything).Return(mockTagService) + _, err := client.Tags() + require.Error(t, err) + }) +} + +func TestManifestForTag(t *testing.T) { + t.Run("Successful retrieval of Manifest", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + manService := new(mocks.ManifestService) + manService.On("Get", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) + mockRegClient.On("Manifests", mock.Anything).Return(manService, nil) + _, err := client.ManifestForTag("tagStr") + require.NoError(t, err) + }) + t.Run("Error returned from Manifests call", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + manService := new(mocks.ManifestService) + manService.On("Get", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) + mockRegClient.On("Manifests", mock.Anything).Return(manService, errors.New("Error on caling func Manifests")) + _, err := client.ManifestForTag("tagStr") + require.Error(t, err) + }) + + t.Run("Error returned from Get call", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + manService := new(mocks.ManifestService) + manService.On("Get", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("Error on caling func Get")) + mockRegClient.On("Manifests", mock.Anything).Return(manService, nil) + _, err := client.ManifestForTag("tagStr") + require.Error(t, err) + }) + +} + +func TestManifestForDigest(t *testing.T) { + t.Run("Successful retrieval of manifest", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + manService := new(mocks.ManifestService) + manService.On("Get", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) + mockRegClient.On("Manifests", mock.Anything).Return(manService, nil) + _, err := client.ManifestForDigest("dgst") + require.NoError(t, err) + }) + t.Run("Error returned from Manifests call", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + manService := new(mocks.ManifestService) + manService.On("Get", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) + mockRegClient.On("Manifests", mock.Anything).Return(manService, errors.New("Error on caling func Manifests")) + _, err := client.ManifestForDigest("dgst") + require.Error(t, err) + }) + t.Run("Error returned from Get call", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + manService := new(mocks.ManifestService) + manService.On("Get", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("Error on caling func Get")) + mockRegClient.On("Manifests", mock.Anything).Return(manService, nil) + _, err := client.ManifestForDigest("dgst") + require.Error(t, err) + }) +} + +func TestTagInfoFromReferences(t *testing.T) { + t.Run("No usable reference in manifest list", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + tagInfo := &tag.TagInfo{} + tagInfo.CreatedAt = time.Now() + tagInfo.Digest = [32]byte{} + opts := &options.ManifestOptions{} + opts.WithPlatform("testOS", "testArch", "testVarient") + opts.WithLogger(log.NewContext()) + opts.WithMetadata(true) + descriptor := []distribution.Descriptor{ + { + MediaType: "", + Digest: "", + Size: 0, + Platform: &v1.Platform{ + Architecture: "mTestArch", + OS: "mTestOS", + OSVersion: "mTestOSVersion", + OSFeatures: []string{}, + Variant: "mTestVarient", + }, + }, + } + tag, err := TagInfoFromReferences(&client, opts, log.NewContext(), tagInfo, descriptor) + require.Nil(t, tag) + require.NoError(t, err) + }) + t.Run("Return tagInfo when metadata option is false", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + tagInfo := &tag.TagInfo{} + tagInfo.CreatedAt = time.Now() + tagInfo.Digest = [32]byte{} + opts := &options.ManifestOptions{} + opts.WithMetadata(false) + opts.WithPlatform("testOS", "testArch", "testVarient") + opts.WithLogger(log.NewContext()) + descriptor := []distribution.Descriptor{ + { + MediaType: "", + Digest: "", + Size: 0, + Platform: &v1.Platform{ + Architecture: "testArch", + OS: "testOS", + OSVersion: "testOSVersion", + OSFeatures: []string{}, + Variant: "testVarient", + }, + }, + } + tag, err := TagInfoFromReferences(&client, opts, log.NewContext(), tagInfo, descriptor) + require.NoError(t, err) + assert.Equal(t, tag, tagInfo) + require.NoError(t, err) + }) + t.Run("Return error from ManifestForDigest", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + tagInfo := &tag.TagInfo{} + tagInfo.CreatedAt = time.Now() + tagInfo.Digest = [32]byte{} + opts := &options.ManifestOptions{} + opts.WithMetadata(true) + opts.WithPlatform("testOS", "testArch", "testVarient") + opts.WithLogger(log.NewContext()) + descriptor := []distribution.Descriptor{ + { + MediaType: "", + Digest: "", + Size: 0, + Platform: &v1.Platform{ + Architecture: "testArch", + OS: "testOS", + OSVersion: "testOSVersion", + OSFeatures: []string{}, + Variant: "testVarient", + }, + }, + } + mockRegClient.On("Manifests", mock.Anything).Return(nil, errors.New("Error from Manifests")) + _, err := TagInfoFromReferences(&client, opts, log.NewContext(), tagInfo, descriptor) + require.Error(t, err) + }) + t.Run("Return error from TagMetadata", func(t *testing.T) { + mockRegClient := new(mocks.Repository) + client := registryClient{ + regClient: mockRegClient, + } + tagInfo := &tag.TagInfo{} + tagInfo.CreatedAt = time.Now() + tagInfo.Digest = [32]byte{} + opts := &options.ManifestOptions{} + opts.WithMetadata(true) + opts.WithPlatform("testOS", "testArch", "testVarient") + opts.WithLogger(log.NewContext()) + descriptor := []distribution.Descriptor{ + { + MediaType: "", + Digest: "", + Size: 0, + Platform: &v1.Platform{ + Architecture: "testArch", + OS: "testOS", + OSVersion: "testOSVersion", + OSFeatures: []string{}, + Variant: "testVarient", + }, + }, + } + manService := new(mocks.ManifestService) + manService.On("Get", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(new(mocks.Manifest), nil) + mockRegClient.On("Manifests", mock.Anything).Return(manService, nil) + _, err := TagInfoFromReferences(&client, opts, log.NewContext(), tagInfo, descriptor) + require.Error(t, err) + }) +} + +func Test_TagMetadata(t *testing.T) { + t.Run("Check for correct error handling when manifest contains no history", func(t *testing.T) { + meta1 := &schema1.SignedManifest{ //nolint:staticcheck + Manifest: schema1.Manifest{ //nolint:staticcheck + History: []schema1.History{}, //nolint:staticcheck + }, + } + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + require.NoError(t, err) + _, err = client.TagMetadata(meta1, &options.ManifestOptions{}) + require.Error(t, err) + }) + + t.Run("Check for correct error handling when manifest contains invalid history", func(t *testing.T) { + meta1 := &schema1.SignedManifest{ //nolint:staticcheck + Manifest: schema1.Manifest{ //nolint:staticcheck + History: []schema1.History{ //nolint:staticcheck + { + V1Compatibility: `{"created": {"something": "notastring"}}`, + }, + }, + }, + } + + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + require.NoError(t, err) + _, err = client.TagMetadata(meta1, &options.ManifestOptions{}) + require.Error(t, err) + }) + + t.Run("Check for correct error handling when manifest contains invalid history", func(t *testing.T) { + meta1 := &schema1.SignedManifest{ //nolint:staticcheck + Manifest: schema1.Manifest{ //nolint:staticcheck + History: []schema1.History{ //nolint:staticcheck + { + V1Compatibility: `{"something": "something"}`, + }, + }, + }, + } + + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + require.NoError(t, err) + _, err = client.TagMetadata(meta1, &options.ManifestOptions{}) + require.Error(t, err) + + }) + + t.Run("Check for invalid/valid timestamp and non-match platforms", func(t *testing.T) { + ts := "invalid" + meta1 := &schema1.SignedManifest{ //nolint:staticcheck + Manifest: schema1.Manifest{ //nolint:staticcheck + History: []schema1.History{ //nolint:staticcheck + { + V1Compatibility: `{"created":"` + ts + `"}`, + }, + }, + }, + } + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + require.NoError(t, err) + _, err = client.TagMetadata(meta1, &options.ManifestOptions{}) + require.Error(t, err) + + ts = time.Now().Format(time.RFC3339Nano) + opts := &options.ManifestOptions{} + meta1.Manifest.History[0].V1Compatibility = `{"created":"` + ts + `"}` + tagInfo, _ := client.TagMetadata(meta1, opts) + assert.Equal(t, ts, tagInfo.CreatedAt.Format(time.RFC3339Nano)) + + opts.WithPlatform("testOS", "testArch", "testVariant") + tagInfo, err = client.TagMetadata(meta1, opts) + assert.Nil(t, tagInfo) + assert.Nil(t, err) + }) +} + +func Test_TagMetadata_2(t *testing.T) { + t.Run("ocischema DeserializedManifest invalid digest format", func(t *testing.T) { + meta1 := &ocischema.DeserializedManifest{ + Manifest: ocischema.Manifest{ + Versioned: manifest.Versioned{ + SchemaVersion: 1, + MediaType: "", + }, + }, + } + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + + require.NoError(t, err) + err = client.NewRepository("test/test") + require.NoError(t, err) + _, err = client.TagMetadata(meta1, &options.ManifestOptions{}) + require.Error(t, err) // invalid digest format + }) + t.Run("schema2 DeserializedManifest invalid digest format", func(t *testing.T) { + meta1 := &schema2.DeserializedManifest{ + Manifest: schema2.Manifest{ + Versioned: manifest.Versioned{ + SchemaVersion: 1, + MediaType: "", + }, + Config: distribution.Descriptor{ + MediaType: "", + Digest: "sha256:abc", + }, + }, + } + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + + require.NoError(t, err) + err = client.NewRepository("test/test") + require.NoError(t, err) + _, err = client.TagMetadata(meta1, &options.ManifestOptions{}) + require.Error(t, err) // invalid digest format + }) + t.Run("ocischema DeserializedImageIndex empty index not supported", func(t *testing.T) { + meta1 := &ocischema.DeserializedImageIndex{ + ImageIndex: ocischema.ImageIndex{ + Versioned: manifest.Versioned{ + SchemaVersion: 1, + MediaType: "", + }, + Manifests: nil, + Annotations: nil, + }, + } + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + + require.NoError(t, err) + err = client.NewRepository("test/test") + require.NoError(t, err) + _, err = client.TagMetadata(meta1, &options.ManifestOptions{}) + require.Error(t, err) // empty index not supported + }) + t.Run("ocischema DeserializedImageIndex empty manifestlist not supported", func(t *testing.T) { + meta1 := &manifestlist.DeserializedManifestList{ + ManifestList: manifestlist.ManifestList{ + Versioned: manifest.Versioned{ + SchemaVersion: 1, + MediaType: "", + }, + Manifests: nil, + }, + } + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + client, err := NewClient(ep, "", "") + + require.NoError(t, err) + err = client.NewRepository("test/test") + require.NoError(t, err) + _, err = client.TagMetadata(meta1, &options.ManifestOptions{}) + require.Error(t, err) // empty manifestlist not supported + }) +} + +func TestPing(t *testing.T) { + t.Run("fail ping", func(t *testing.T) { + mockManager := new(mocks.Manager) + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + mockManager.On("AddResponse", mock.Anything).Return(fmt.Errorf("fail ping")) + _, err = ping(mockManager, ep, "") + require.Error(t, err) + }) + + t.Run("success ping", func(t *testing.T) { + mockManager := new(mocks.Manager) + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + mockManager.On("AddResponse", mock.Anything).Return(nil) + _, err = ping(mockManager, ep, "") + require.NoError(t, err) + }) + + t.Run("Invalid Docker Registry", func(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + mockManager := new(mocks.Manager) + ep := &RegistryEndpoint{RegistryAPI: testServer.URL} + mockManager.On("AddResponse", mock.Anything).Return(nil) + _, err := ping(mockManager, ep, "") + require.Error(t, err) + assert.ErrorContains(t, err, "does not seem to be a valid v2 Docker Registry API") + }) + + t.Run("Empty Registry API", func(t *testing.T) { + mockManager := new(mocks.Manager) + ep := &RegistryEndpoint{RegistryAPI: ""} + mockManager.On("AddResponse", mock.Anything).Return(nil) + _, err := ping(mockManager, ep, "") + require.Error(t, err) + assert.ErrorContains(t, err, "unsupported protocol scheme") + }) + +} diff --git a/registry-scanner/pkg/registry/config.go b/registry-scanner/pkg/registry/config.go new file mode 100644 index 0000000..641c598 --- /dev/null +++ b/registry-scanner/pkg/registry/config.go @@ -0,0 +1,139 @@ +package registry + +import ( + "fmt" + "os" + "time" + + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/log" + + "gopkg.in/yaml.v2" +) + +// RegistryConfiguration represents a single repository configuration for being +// unmarshaled from YAML. +type RegistryConfiguration struct { + Name string `yaml:"name"` + ApiURL string `yaml:"api_url"` + Ping bool `yaml:"ping,omitempty"` + Credentials string `yaml:"credentials,omitempty"` + CredsExpire time.Duration `yaml:"credsexpire,omitempty"` + TagSortMode string `yaml:"tagsortmode,omitempty"` + Prefix string `yaml:"prefix,omitempty"` + Insecure bool `yaml:"insecure,omitempty"` + DefaultNS string `yaml:"defaultns,omitempty"` + Limit int `yaml:"limit,omitempty"` + IsDefault bool `yaml:"default,omitempty"` +} + +// RegistryList contains multiple RegistryConfiguration items +type RegistryList struct { + Items []RegistryConfiguration `yaml:"registries"` +} + +func clearRegistries() { + registryLock.Lock() + registries = make(map[string]*RegistryEndpoint) + registryLock.Unlock() +} + +// LoadRegistryConfiguration loads a YAML-formatted registry configuration from +// a given file at path. +func LoadRegistryConfiguration(path string, clear bool) error { + registryBytes, err := os.ReadFile(path) + if err != nil { + return err + } + registryList, err := ParseRegistryConfiguration(string(registryBytes)) + if err != nil { + return err + } + + if clear { + clearRegistries() + } + + haveDefault := false + + for _, reg := range registryList.Items { + tagSortMode := TagListSortFromString(reg.TagSortMode) + if tagSortMode != TagListSortUnsorted { + log.Warnf("Registry %s has tag sort mode set to %s, meta data retrieval will be disabled for this registry.", reg.ApiURL, tagSortMode) + } + ep := NewRegistryEndpoint(reg.Prefix, reg.Name, reg.ApiURL, reg.Credentials, reg.DefaultNS, reg.Insecure, tagSortMode, reg.Limit, reg.CredsExpire) + if reg.IsDefault { + if haveDefault { + dep := GetDefaultRegistry() + if dep == nil { + panic("unexpected: default registry should be set, but is not") + } + return fmt.Errorf("cannot set registry %s as default - only one default registry allowed, currently set to %s", ep.RegistryPrefix, dep.RegistryPrefix) + } + } + + if err := AddRegistryEndpoint(ep); err != nil { + return err + } + + if reg.IsDefault { + SetDefaultRegistry(ep) + haveDefault = true + } + } + + log.Infof("Loaded %d registry configurations from %s", len(registryList.Items), path) + return nil +} + +// Parses a registry configuration from a YAML input string and returns a list +// of registries. +func ParseRegistryConfiguration(yamlSource string) (RegistryList, error) { + var regList RegistryList + var defaultPrefixFound = "" + err := yaml.UnmarshalStrict([]byte(yamlSource), ®List) + if err != nil { + return RegistryList{}, err + } + + // validate the parsed list + for _, registry := range regList.Items { + if registry.Name == "" { + err = fmt.Errorf("registry name is missing for entry %v", registry) + } else if registry.ApiURL == "" { + err = fmt.Errorf("API URL must be specified for registry %s", registry.Name) + } else if registry.Prefix == "" { + if defaultPrefixFound != "" { + err = fmt.Errorf("there must be only one default registry (already is %s), %s needs a prefix", defaultPrefixFound, registry.Name) + } else { + defaultPrefixFound = registry.Name + } + } + + if err == nil { + if tls := TagListSortFromString(registry.TagSortMode); tls == TagListSortUnknown { + err = fmt.Errorf("unknown tag sort mode for registry %s: %s", registry.Name, registry.TagSortMode) + } + } + } + + if err != nil { + return RegistryList{}, err + } + + return regList, nil +} + +// RestRestoreDefaultRegistryConfiguration restores the registry configuration +// to the default values. +func RestoreDefaultRegistryConfiguration() { + registryLock.Lock() + defer registryLock.Unlock() + defaultRegistry = nil + registries = make(map[string]*RegistryEndpoint) + for k, v := range registryTweaks { + registries[k] = v.DeepCopy() + if v.IsDefault { + SetDefaultRegistry(registries[k]) + } + } +} diff --git a/registry-scanner/pkg/registry/config_test.go b/registry-scanner/pkg/registry/config_test.go new file mode 100644 index 0000000..5664ecf --- /dev/null +++ b/registry-scanner/pkg/registry/config_test.go @@ -0,0 +1,110 @@ +package registry + +import ( + "testing" + "time" + + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/test/fixture" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ParseRegistryConfFromYaml(t *testing.T) { + t.Run("Parse from valid YAML", func(t *testing.T) { + data := fixture.MustReadFile("../../config/example-config.yaml") + regList, err := ParseRegistryConfiguration(data) + require.NoError(t, err) + assert.Len(t, regList.Items, 4) + }) + + t.Run("Parse from invalid YAML: no name found", func(t *testing.T) { + registries := ` +registries: +- api_url: https://foo.io + ping: false +` + regList, err := ParseRegistryConfiguration(registries) + require.Error(t, err) + assert.Contains(t, err.Error(), "name is missing") + assert.Len(t, regList.Items, 0) + }) + + t.Run("Parse from invalid YAML: no API URL found", func(t *testing.T) { + registries := ` +registries: +- name: Foobar Registry + ping: false +` + regList, err := ParseRegistryConfiguration(registries) + require.Error(t, err) + assert.Contains(t, err.Error(), "API URL must be") + assert.Len(t, regList.Items, 0) + }) + + t.Run("Parse from invalid YAML: multiple registries without prefix", func(t *testing.T) { + registries := ` +registries: +- name: Foobar Registry + api_url: https://foobar.io + ping: false +- name: Barbar Registry + api_url: https://barbar.io + ping: false +` + regList, err := ParseRegistryConfiguration(registries) + require.Error(t, err) + assert.Contains(t, err.Error(), "already is Foobar Registry") + assert.Len(t, regList.Items, 0) + }) + + t.Run("Parse from invalid YAML: invalid tag sort mode", func(t *testing.T) { + registries := ` +registries: +- name: Foobar Registry + api_url: https://foobar.io + ping: false + tagsortmode: invalid +` + regList, err := ParseRegistryConfiguration(registries) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown tag sort mode") + assert.Len(t, regList.Items, 0) + }) + +} + +func Test_LoadRegistryConfiguration(t *testing.T) { + RestoreDefaultRegistryConfiguration() + + t.Run("Load from valid location", func(t *testing.T) { + err := LoadRegistryConfiguration("../../config/example-config.yaml", true) + require.NoError(t, err) + assert.Len(t, registries, 4) + reg, err := GetRegistryEndpoint("gcr.io") + require.NoError(t, err) + assert.Equal(t, "pullsecret:foo/bar", reg.Credentials) + reg, err = GetRegistryEndpoint("ghcr.io") + require.NoError(t, err) + assert.Equal(t, "ext:/some/script", reg.Credentials) + assert.Equal(t, 5*time.Hour, reg.CredsExpire) + RestoreDefaultRegistryConfiguration() + reg, err = GetRegistryEndpoint("gcr.io") + require.NoError(t, err) + assert.Equal(t, "", reg.Credentials) + }) + + t.Run("Load from invalid location", func(t *testing.T) { + err := LoadRegistryConfiguration("../../test/testdata/registry/config/does-not-exist.yaml", true) + require.Error(t, err) + require.Contains(t, err.Error(), "no such file or directory") + }) + + t.Run("Two defaults defined in same config", func(t *testing.T) { + err := LoadRegistryConfiguration("../../test/testdata/registry/config/two-defaults.yaml", true) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot set registry") + }) + + RestoreDefaultRegistryConfiguration() +} diff --git a/registry-scanner/pkg/registry/endpoints.go b/registry-scanner/pkg/registry/endpoints.go new file mode 100644 index 0000000..3b64fc5 --- /dev/null +++ b/registry-scanner/pkg/registry/endpoints.go @@ -0,0 +1,305 @@ +package registry + +import ( + "crypto/tls" + "fmt" + "math" + "net/http" + "strings" + "sync" + "time" + + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/cache" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/log" + + "go.uber.org/ratelimit" +) + +// TagListSort defines how the registry returns the list of tags +type TagListSort int + +const ( + TagListSortUnknown TagListSort = -1 + TagListSortUnsorted TagListSort = 0 + TagListSortLatestFirst TagListSort = 1 + TagListSortLatestLast TagListSort = 2 + TagListSortUnsortedString string = "unsorted" + TagListSortLatestFirstString string = "latest-first" + TagListSortLatestLastString string = "latest-last" + TagListSortUnknownString string = "unknown" +) + +const ( + RateLimitNone = math.MaxInt32 + RateLimitDefault = 10 +) + +// IsTimeSorted returns whether a tag list is time sorted +func (tls TagListSort) IsTimeSorted() bool { + return tls == TagListSortLatestFirst || tls == TagListSortLatestLast +} + +// TagListSortFromString gets the TagListSort value from a given string +func TagListSortFromString(tls string) TagListSort { + switch strings.ToLower(tls) { + case "latest-first": + return TagListSortLatestFirst + case "latest-last": + return TagListSortLatestLast + case "none", "": + return TagListSortUnsorted + default: + log.Warnf("unknown tag list sort mode: %s", tls) + return TagListSortUnknown + } +} + +// String returns the string representation of a TagListSort value +func (tls TagListSort) String() string { + switch tls { + case TagListSortLatestFirst: + return TagListSortLatestFirstString + case TagListSortLatestLast: + return TagListSortLatestLastString + case TagListSortUnsorted: + return TagListSortUnsortedString + } + + return TagListSortUnknownString +} + +// RegistryEndpoint holds information on how to access any specific registry API +// endpoint. +type RegistryEndpoint struct { + RegistryName string + RegistryPrefix string + RegistryAPI string + Username string + Password string + Ping bool + Credentials string + Insecure bool + DefaultNS string + CredsExpire time.Duration + CredsUpdated time.Time + TagListSort TagListSort + Cache cache.ImageTagCache + Limiter ratelimit.Limiter + IsDefault bool + lock sync.RWMutex + limit int +} + +// registryTweaks should contain a list of registries whose settings cannot be +// inferred by just looking at the image prefix. Prominent example here is the +// Docker Hub registry, which is referred to as docker.io from the image, but +// its API endpoint is https://registry-1.docker.io (and not https://docker.io) +var registryTweaks map[string]*RegistryEndpoint = map[string]*RegistryEndpoint{ + "docker.io": { + RegistryName: "Docker Hub", + RegistryPrefix: "docker.io", + RegistryAPI: "https://registry-1.docker.io", + Ping: true, + Insecure: false, + DefaultNS: "library", + Cache: cache.NewMemCache(), + Limiter: ratelimit.New(RateLimitDefault), + IsDefault: true, + }, +} + +var registries map[string]*RegistryEndpoint = make(map[string]*RegistryEndpoint) + +// Default registry points to the registry that is to be used as the default, +// e.g. when no registry prefix is given for a certain image. +var defaultRegistry *RegistryEndpoint + +// Simple RW mutex for concurrent access to registries map +var registryLock sync.RWMutex + +func AddRegistryEndpointFromConfig(epc RegistryConfiguration) error { + ep := NewRegistryEndpoint(epc.Prefix, epc.Name, epc.ApiURL, epc.Credentials, epc.DefaultNS, epc.Insecure, TagListSortFromString(epc.TagSortMode), epc.Limit, epc.CredsExpire) + return AddRegistryEndpoint(ep) +} + +// NewRegistryEndpoint returns an endpoint object with the given configuration +// pre-populated and a fresh cache. +func NewRegistryEndpoint(prefix, name, apiUrl, credentials, defaultNS string, insecure bool, tagListSort TagListSort, limit int, credsExpire time.Duration) *RegistryEndpoint { + if limit <= 0 { + limit = RateLimitNone + } + ep := &RegistryEndpoint{ + RegistryName: name, + RegistryPrefix: prefix, + RegistryAPI: strings.TrimSuffix(apiUrl, "/"), + Credentials: credentials, + CredsExpire: credsExpire, + Cache: cache.NewMemCache(), + Insecure: insecure, + DefaultNS: defaultNS, + TagListSort: tagListSort, + Limiter: ratelimit.New(limit), + limit: limit, + } + return ep +} + +// AddRegistryEndpoint adds registry endpoint information with the given details +func AddRegistryEndpoint(ep *RegistryEndpoint) error { + prefix := ep.RegistryPrefix + + registryLock.Lock() + // If the endpoint is supposed to be the default endpoint, make sure that + // any previously set default endpoint is unset. + if ep.IsDefault { + if dep := GetDefaultRegistry(); dep != nil { + dep.IsDefault = false + } + SetDefaultRegistry(ep) + } + registries[prefix] = ep + registryLock.Unlock() + + logCtx := log.WithContext() + logCtx.AddField("registry", ep.RegistryAPI) + logCtx.AddField("prefix", ep.RegistryPrefix) + if ep.limit != RateLimitNone { + logCtx.Debugf("setting rate limit to %d requests per second", ep.limit) + } else { + logCtx.Debugf("rate limiting is disabled") + } + return nil +} + +// inferRegistryEndpointFromPrefix returns a registry endpoint with the API +// URL inferred from the prefix and adds it to the list of the configured +// registries. +func inferRegistryEndpointFromPrefix(prefix string) *RegistryEndpoint { + apiURL := "https://" + prefix + return NewRegistryEndpoint(prefix, prefix, apiURL, "", "", false, TagListSortUnsorted, 20, 0) +} + +// GetRegistryEndpoint retrieves the endpoint information for the given prefix +func GetRegistryEndpoint(prefix string) (*RegistryEndpoint, error) { + if prefix == "" { + if defaultRegistry == nil { + return nil, fmt.Errorf("no default endpoint configured") + } else { + return defaultRegistry, nil + } + } + + registryLock.RLock() + registry, ok := registries[prefix] + registryLock.RUnlock() + + if ok { + return registry, nil + } else { + var err error + ep := inferRegistryEndpointFromPrefix(prefix) + if ep != nil { + err = AddRegistryEndpoint(ep) + } else { + err = fmt.Errorf("could not infer registry configuration from prefix %s", prefix) + } + if err == nil { + log.Debugf("Inferred registry from prefix %s to use API %s", prefix, ep.RegistryAPI) + } + return ep, err + } +} + +// SetDefaultRegistry sets a given registry endpoint as the default +func SetDefaultRegistry(ep *RegistryEndpoint) { + log.Debugf("Setting default registry endpoint to %s", ep.RegistryPrefix) + ep.IsDefault = true + if defaultRegistry != nil { + log.Debugf("Previous default registry was %s", defaultRegistry.RegistryPrefix) + defaultRegistry.IsDefault = false + } + defaultRegistry = ep +} + +// GetDefaultRegistry returns the registry endpoint that is set as default, +// or nil if no default registry endpoint is set +func GetDefaultRegistry() *RegistryEndpoint { + if defaultRegistry != nil { + log.Debugf("Getting default registry endpoint: %s", defaultRegistry.RegistryPrefix) + } else { + log.Debugf("No default registry defined.") + } + return defaultRegistry +} + +// SetRegistryEndpointCredentials allows to change the credentials used for +// endpoint access for existing RegistryEndpoint configuration +func SetRegistryEndpointCredentials(prefix, credentials string) error { + registry, err := GetRegistryEndpoint(prefix) + if err != nil { + return err + } + registry.lock.Lock() + registry.Credentials = credentials + registry.lock.Unlock() + return nil +} + +// ConfiguredEndpoints returns a list of prefixes that are configured +func ConfiguredEndpoints() []string { + registryLock.RLock() + defer registryLock.RUnlock() + r := make([]string, 0, len(registries)) + for _, v := range registries { + r = append(r, v.RegistryPrefix) + } + return r +} + +// DeepCopy copies the endpoint to a new object, but creating a new Cache +func (ep *RegistryEndpoint) DeepCopy() *RegistryEndpoint { + ep.lock.RLock() + newEp := &RegistryEndpoint{} + newEp.RegistryAPI = ep.RegistryAPI + newEp.RegistryName = ep.RegistryName + newEp.RegistryPrefix = ep.RegistryPrefix + newEp.Credentials = ep.Credentials + newEp.Ping = ep.Ping + newEp.TagListSort = ep.TagListSort + newEp.Cache = cache.NewMemCache() + newEp.Insecure = ep.Insecure + newEp.DefaultNS = ep.DefaultNS + newEp.Limiter = ep.Limiter + newEp.CredsExpire = ep.CredsExpire + newEp.CredsUpdated = ep.CredsUpdated + newEp.IsDefault = ep.IsDefault + newEp.limit = ep.limit + ep.lock.RUnlock() + return newEp +} + +// GetTransport returns a transport object for this endpoint +func (ep *RegistryEndpoint) GetTransport() *http.Transport { + tlsC := &tls.Config{} + if ep.Insecure { + tlsC.InsecureSkipVerify = true + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: tlsC, + } +} + +// init initializes the registry configuration +func init() { + for k, v := range registryTweaks { + registries[k] = v.DeepCopy() + if v.IsDefault { + if defaultRegistry == nil { + defaultRegistry = v + } else { + panic("only one default registry can be configured") + } + } + } +} diff --git a/registry-scanner/pkg/registry/endpoints_test.go b/registry-scanner/pkg/registry/endpoints_test.go new file mode 100644 index 0000000..d2dae1f --- /dev/null +++ b/registry-scanner/pkg/registry/endpoints_test.go @@ -0,0 +1,354 @@ +package registry + +import ( + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInferRegistryEndpointFromPrefix(t *testing.T) { + prefix := "example.com" + expectedAPIURL := "https://" + prefix + endpoint := inferRegistryEndpointFromPrefix(prefix) + assert.NotNil(t, endpoint) + assert.Equal(t, prefix, endpoint.RegistryName) + assert.Equal(t, prefix, endpoint.RegistryPrefix) + assert.Equal(t, expectedAPIURL, endpoint.RegistryAPI) + assert.Equal(t, TagListSortUnsorted, endpoint.TagListSort) + assert.Equal(t, 20, endpoint.limit) + assert.False(t, endpoint.Insecure) +} + +func TestNewRegistryEndpoint(t *testing.T) { + prefix := "example.com" + name := "exampleRegistry" + apiUrl := "https://api.example.com" + credentials := "user:pass" + defaultNS := "default" + insecure := true + tagListSort := TagListSortLatestFirst + limit := 10 + credsExpire := time.Minute * 30 + endpoint := NewRegistryEndpoint(prefix, name, apiUrl, credentials, defaultNS, insecure, tagListSort, limit, credsExpire) + assert.NotNil(t, endpoint) + assert.Equal(t, name, endpoint.RegistryName) + assert.Equal(t, prefix, endpoint.RegistryPrefix) + assert.Equal(t, strings.TrimSuffix(apiUrl, "/"), endpoint.RegistryAPI) + assert.Equal(t, credentials, endpoint.Credentials) + assert.Equal(t, credsExpire, endpoint.CredsExpire) + assert.Equal(t, insecure, endpoint.Insecure) + assert.Equal(t, defaultNS, endpoint.DefaultNS) + assert.Equal(t, tagListSort, endpoint.TagListSort) + assert.Equal(t, limit, endpoint.limit) +} + +func TestTagListSortFromString(t *testing.T) { + t.Run("returns TagListSortLatestFirst for 'latest-first'", func(t *testing.T) { + result := TagListSortFromString("latest-first") + assert.Equal(t, TagListSortLatestFirst, result) + }) + + t.Run("returns TagListSortLatestLast for 'latest-last'", func(t *testing.T) { + result := TagListSortFromString("latest-last") + assert.Equal(t, TagListSortLatestLast, result) + }) + + t.Run("returns TagListSortUnsorted for 'none'", func(t *testing.T) { + result := TagListSortFromString("none") + assert.Equal(t, TagListSortUnsorted, result) + }) + + t.Run("returns TagListSortUnsorted for an empty string", func(t *testing.T) { + result := TagListSortFromString("") + assert.Equal(t, TagListSortUnsorted, result) + }) + + t.Run("returns TagListSortUnknown for an unknown value", func(t *testing.T) { + result := TagListSortFromString("unknown-value") + assert.Equal(t, TagListSortUnknown, result) + }) +} + +func TestIsTimeSorted(t *testing.T) { + t.Run("returns true for TagListSortLatestFirst", func(t *testing.T) { + assert.True(t, TagListSortLatestFirst.IsTimeSorted()) + }) + t.Run("returns true for TagListSortLatestLast", func(t *testing.T) { + assert.True(t, TagListSortLatestLast.IsTimeSorted()) + }) + t.Run("returns false for TagListSortUnsorted", func(t *testing.T) { + assert.False(t, TagListSortUnsorted.IsTimeSorted()) + }) + t.Run("returns false for TagListSortUnknown", func(t *testing.T) { + assert.False(t, TagListSortUnknown.IsTimeSorted()) + }) +} + +func TestTagListSort_String(t *testing.T) { + t.Run("returns 'latest-first' for TagListSortLatestFirst", func(t *testing.T) { + assert.Equal(t, TagListSortLatestFirstString, TagListSortLatestFirst.String()) + }) + + t.Run("returns 'latest-last' for TagListSortLatestLast", func(t *testing.T) { + assert.Equal(t, TagListSortLatestLastString, TagListSortLatestLast.String()) + }) + + t.Run("returns 'unsorted' for TagListSortUnsorted", func(t *testing.T) { + assert.Equal(t, TagListSortUnsortedString, TagListSortUnsorted.String()) + }) + + t.Run("returns 'unknown' for TagListSortUnknown", func(t *testing.T) { + assert.Equal(t, TagListSortUnknownString, TagListSortUnknown.String()) + }) + + t.Run("returns 'unknown' for an undefined TagListSort value", func(t *testing.T) { + var undefined TagListSort = 99 + assert.Equal(t, TagListSortUnknownString, undefined.String()) + }) +} + +func Test_GetEndpoints(t *testing.T) { + RestoreDefaultRegistryConfiguration() + + t.Run("Get default endpoint", func(t *testing.T) { + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + require.NotNil(t, ep) + assert.Equal(t, "docker.io", ep.RegistryPrefix) + }) + + t.Run("Get GCR endpoint", func(t *testing.T) { + ep, err := GetRegistryEndpoint("gcr.io") + require.NoError(t, err) + require.NotNil(t, ep) + assert.Equal(t, ep.RegistryPrefix, "gcr.io") + }) + + t.Run("Infer endpoint", func(t *testing.T) { + ep, err := GetRegistryEndpoint("foobar.com") + require.NoError(t, err) + require.NotNil(t, ep) + assert.Equal(t, "foobar.com", ep.RegistryPrefix) + assert.Equal(t, "https://foobar.com", ep.RegistryAPI) + }) +} + +func Test_AddEndpoint(t *testing.T) { + RestoreDefaultRegistryConfiguration() + + t.Run("Add new endpoint", func(t *testing.T) { + err := AddRegistryEndpoint(NewRegistryEndpoint("example.com", "Example", "https://example.com", "", "", false, TagListSortUnsorted, 5, 0)) + require.NoError(t, err) + }) + t.Run("Get example.com endpoint", func(t *testing.T) { + ep, err := GetRegistryEndpoint("example.com") + require.NoError(t, err) + require.NotNil(t, ep) + assert.Equal(t, ep.RegistryPrefix, "example.com") + assert.Equal(t, ep.RegistryName, "Example") + assert.Equal(t, ep.RegistryAPI, "https://example.com") + assert.Equal(t, ep.Insecure, false) + assert.Equal(t, ep.DefaultNS, "") + assert.Equal(t, ep.TagListSort, TagListSortUnsorted) + }) + t.Run("Change existing endpoint", func(t *testing.T) { + err := AddRegistryEndpoint(NewRegistryEndpoint("example.com", "Example", "https://example.com", "", "library", true, TagListSortLatestFirst, 5, 0)) + require.NoError(t, err) + ep, err := GetRegistryEndpoint("example.com") + require.NoError(t, err) + require.NotNil(t, ep) + assert.Equal(t, ep.Insecure, true) + assert.Equal(t, ep.DefaultNS, "library") + assert.Equal(t, ep.TagListSort, TagListSortLatestFirst) + }) +} + +func Test_SetEndpointCredentials(t *testing.T) { + RestoreDefaultRegistryConfiguration() + + t.Run("Set credentials on default registry", func(t *testing.T) { + err := SetRegistryEndpointCredentials("", "env:FOOBAR") + require.NoError(t, err) + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + require.NotNil(t, ep) + assert.Equal(t, ep.Credentials, "env:FOOBAR") + }) + + t.Run("Unset credentials on default registry", func(t *testing.T) { + err := SetRegistryEndpointCredentials("", "") + require.NoError(t, err) + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + require.NotNil(t, ep) + assert.Equal(t, ep.Credentials, "") + }) +} + +func Test_EndpointConcurrentAccess(t *testing.T) { + RestoreDefaultRegistryConfiguration() + const numRuns = 50 + // Make sure we're not deadlocking on read + t.Run("Concurrent read access", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(numRuns) + for i := 0; i < numRuns; i++ { + go func() { + ep, err := GetRegistryEndpoint("gcr.io") + require.NoError(t, err) + require.NotNil(t, ep) + wg.Done() + }() + } + wg.Wait() + }) + + // Make sure we're not deadlocking on write + t.Run("Concurrent write access", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(numRuns) + for i := 0; i < numRuns; i++ { + go func(i int) { + creds := fmt.Sprintf("secret:foo/secret-%d", i) + err := SetRegistryEndpointCredentials("", creds) + require.NoError(t, err) + ep, err := GetRegistryEndpoint("") + require.NoError(t, err) + require.NotNil(t, ep) + wg.Done() + }(i) + } + wg.Wait() + }) +} + +func Test_SetDefault(t *testing.T) { + RestoreDefaultRegistryConfiguration() + + dep := GetDefaultRegistry() + require.NotNil(t, dep) + assert.Equal(t, "docker.io", dep.RegistryPrefix) + assert.True(t, dep.IsDefault) + + ep, err := GetRegistryEndpoint("ghcr.io") + require.NoError(t, err) + require.NotNil(t, ep) + require.False(t, ep.IsDefault) + + SetDefaultRegistry(ep) + assert.True(t, ep.IsDefault) + assert.False(t, dep.IsDefault) + require.NotNil(t, GetDefaultRegistry()) + assert.Equal(t, ep.RegistryPrefix, GetDefaultRegistry().RegistryPrefix) +} + +func Test_DeepCopy(t *testing.T) { + t.Run("DeepCopy endpoint object", func(t *testing.T) { + ep, err := GetRegistryEndpoint("docker.pkg.github.com") + require.NoError(t, err) + require.NotNil(t, ep) + newEp := ep.DeepCopy() + assert.Equal(t, ep.RegistryAPI, newEp.RegistryAPI) + assert.Equal(t, ep.RegistryName, newEp.RegistryName) + assert.Equal(t, ep.RegistryPrefix, newEp.RegistryPrefix) + assert.Equal(t, ep.Credentials, newEp.Credentials) + assert.Equal(t, ep.TagListSort, newEp.TagListSort) + assert.Equal(t, ep.Username, newEp.Username) + assert.Equal(t, ep.Ping, newEp.Ping) + }) +} + +func Test_GetTagListSortFromString(t *testing.T) { + t.Run("Get latest-first sorting", func(t *testing.T) { + tls := TagListSortFromString("latest-first") + assert.Equal(t, TagListSortLatestFirst, tls) + }) + t.Run("Get latest-last sorting", func(t *testing.T) { + tls := TagListSortFromString("latest-last") + assert.Equal(t, TagListSortLatestLast, tls) + }) + t.Run("Get none sorting explicit", func(t *testing.T) { + tls := TagListSortFromString("none") + assert.Equal(t, TagListSortUnsorted, tls) + }) + t.Run("Get none sorting implicit", func(t *testing.T) { + tls := TagListSortFromString("") + assert.Equal(t, TagListSortUnsorted, tls) + }) + t.Run("Get unknown sorting from unknown string", func(t *testing.T) { + tls := TagListSortFromString("unknown") + assert.Equal(t, TagListSortUnknown, tls) + }) +} + +func TestGetTransport(t *testing.T) { + t.Run("returns transport with default TLS config when Insecure is false", func(t *testing.T) { + endpoint := &RegistryEndpoint{ + Insecure: false, + } + transport := endpoint.GetTransport() + + assert.NotNil(t, transport) + assert.NotNil(t, transport.TLSClientConfig) + assert.False(t, transport.TLSClientConfig.InsecureSkipVerify) + }) + + t.Run("returns transport with insecure TLS config when Insecure is true", func(t *testing.T) { + endpoint := &RegistryEndpoint{ + Insecure: true, + } + transport := endpoint.GetTransport() + + assert.NotNil(t, transport) + assert.NotNil(t, transport.TLSClientConfig) + assert.True(t, transport.TLSClientConfig.InsecureSkipVerify) + }) +} + +func Test_RestoreDefaultRegistryConfiguration(t *testing.T) { + // Call the function to restore default configuration + RestoreDefaultRegistryConfiguration() + + // Retrieve the default registry endpoint + defaultEp := GetDefaultRegistry() + + // Validate that the default registry endpoint is not nil + require.NotNil(t, defaultEp) + + // Validate that the default registry endpoint has expected properties + assert.Equal(t, "docker.io", defaultEp.RegistryPrefix) + assert.True(t, defaultEp.IsDefault) +} + +func TestConfiguredEndpoints(t *testing.T) { + // Test the function + endpoints := ConfiguredEndpoints() + // Validate the output + expected := []string{"docker.io"} + require.Len(t, endpoints, len(expected), "The number of endpoints should match the expected number") + assert.ElementsMatch(t, expected, endpoints, "The endpoints should match the expected values") + +} + +func TestAddRegistryEndpointFromConfig(t *testing.T) { + t.Run("successfully adds registry endpoint from config", func(t *testing.T) { + config := RegistryConfiguration{ + Prefix: "example.com", + Name: "exampleRegistry", + ApiURL: "https://api.example.com", + Credentials: "user:pass", + DefaultNS: "default", + Insecure: true, + TagSortMode: "latest-first", + Limit: 10, + CredsExpire: time.Minute * 30, + } + err := AddRegistryEndpointFromConfig(config) + require.NoError(t, err) + }) +} diff --git a/registry-scanner/pkg/registry/mocks/Limiter.go b/registry-scanner/pkg/registry/mocks/Limiter.go new file mode 100644 index 0000000..81dbb19 --- /dev/null +++ b/registry-scanner/pkg/registry/mocks/Limiter.go @@ -0,0 +1,46 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// Limiter is an autogenerated mock type for the Limiter type +type Limiter struct { + mock.Mock +} + +// Take provides a mock function with given fields: +func (_m *Limiter) Take() time.Time { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Take") + } + + var r0 time.Time + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + return r0 +} + +// NewLimiter creates a new instance of Limiter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewLimiter(t interface { + mock.TestingT + Cleanup(func()) +}) *Limiter { + mock := &Limiter{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/registry-scanner/pkg/registry/mocks/Manager.go b/registry-scanner/pkg/registry/mocks/Manager.go new file mode 100644 index 0000000..02c3776 --- /dev/null +++ b/registry-scanner/pkg/registry/mocks/Manager.go @@ -0,0 +1,80 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + http "net/http" + + challenge "github.com/distribution/distribution/v3/registry/client/auth/challenge" + + mock "github.com/stretchr/testify/mock" + + url "net/url" +) + +// Manager is an autogenerated mock type for the Manager type +type Manager struct { + mock.Mock +} + +// AddResponse provides a mock function with given fields: resp +func (_m *Manager) AddResponse(resp *http.Response) error { + ret := _m.Called(resp) + + if len(ret) == 0 { + panic("no return value specified for AddResponse") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*http.Response) error); ok { + r0 = rf(resp) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetChallenges provides a mock function with given fields: endpoint +func (_m *Manager) GetChallenges(endpoint url.URL) ([]challenge.Challenge, error) { + ret := _m.Called(endpoint) + + if len(ret) == 0 { + panic("no return value specified for GetChallenges") + } + + var r0 []challenge.Challenge + var r1 error + if rf, ok := ret.Get(0).(func(url.URL) ([]challenge.Challenge, error)); ok { + return rf(endpoint) + } + if rf, ok := ret.Get(0).(func(url.URL) []challenge.Challenge); ok { + r0 = rf(endpoint) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]challenge.Challenge) + } + } + + if rf, ok := ret.Get(1).(func(url.URL) error); ok { + r1 = rf(endpoint) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewManager creates a new instance of Manager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewManager(t interface { + mock.TestingT + Cleanup(func()) +}) *Manager { + mock := &Manager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/registry-scanner/pkg/registry/mocks/Manifest.go b/registry-scanner/pkg/registry/mocks/Manifest.go new file mode 100644 index 0000000..8b92f36 --- /dev/null +++ b/registry-scanner/pkg/registry/mocks/Manifest.go @@ -0,0 +1,84 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + distribution "github.com/distribution/distribution/v3" + mock "github.com/stretchr/testify/mock" +) + +// Manifest is an autogenerated mock type for the Manifest type +type Manifest struct { + mock.Mock +} + +// Payload provides a mock function with given fields: +func (_m *Manifest) Payload() (string, []byte, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Payload") + } + + var r0 string + var r1 []byte + var r2 error + if rf, ok := ret.Get(0).(func() (string, []byte, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func() []byte); ok { + r1 = rf() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]byte) + } + } + + if rf, ok := ret.Get(2).(func() error); ok { + r2 = rf() + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// References provides a mock function with given fields: +func (_m *Manifest) References() []distribution.Descriptor { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for References") + } + + var r0 []distribution.Descriptor + if rf, ok := ret.Get(0).(func() []distribution.Descriptor); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]distribution.Descriptor) + } + } + + return r0 +} + +// NewManifest creates a new instance of Manifest. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewManifest(t interface { + mock.TestingT + Cleanup(func()) +}) *Manifest { + mock := &Manifest{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/registry-scanner/pkg/registry/mocks/ManifestService.go b/registry-scanner/pkg/registry/mocks/ManifestService.go new file mode 100644 index 0000000..cb42e32 --- /dev/null +++ b/registry-scanner/pkg/registry/mocks/ManifestService.go @@ -0,0 +1,149 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + distribution "github.com/distribution/distribution/v3" + digest "github.com/opencontainers/go-digest" + + mock "github.com/stretchr/testify/mock" +) + +// ManifestService is an autogenerated mock type for the ManifestService type +type ManifestService struct { + mock.Mock +} + +// Delete provides a mock function with given fields: ctx, dgst +func (_m *ManifestService) Delete(ctx context.Context, dgst digest.Digest) error { + ret := _m.Called(ctx, dgst) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, digest.Digest) error); ok { + r0 = rf(ctx, dgst) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Exists provides a mock function with given fields: ctx, dgst +func (_m *ManifestService) Exists(ctx context.Context, dgst digest.Digest) (bool, error) { + ret := _m.Called(ctx, dgst) + + if len(ret) == 0 { + panic("no return value specified for Exists") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, digest.Digest) (bool, error)); ok { + return rf(ctx, dgst) + } + if rf, ok := ret.Get(0).(func(context.Context, digest.Digest) bool); ok { + r0 = rf(ctx, dgst) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, digest.Digest) error); ok { + r1 = rf(ctx, dgst) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Get provides a mock function with given fields: ctx, dgst, options +func (_m *ManifestService) Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error) { + _va := make([]interface{}, len(options)) + for _i := range options { + _va[_i] = options[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, dgst) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 distribution.Manifest + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, digest.Digest, ...distribution.ManifestServiceOption) (distribution.Manifest, error)); ok { + return rf(ctx, dgst, options...) + } + if rf, ok := ret.Get(0).(func(context.Context, digest.Digest, ...distribution.ManifestServiceOption) distribution.Manifest); ok { + r0 = rf(ctx, dgst, options...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(distribution.Manifest) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, digest.Digest, ...distribution.ManifestServiceOption) error); ok { + r1 = rf(ctx, dgst, options...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Put provides a mock function with given fields: ctx, manifest, options +func (_m *ManifestService) Put(ctx context.Context, manifest distribution.Manifest, options ...distribution.ManifestServiceOption) (digest.Digest, error) { + _va := make([]interface{}, len(options)) + for _i := range options { + _va[_i] = options[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, manifest) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Put") + } + + var r0 digest.Digest + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, distribution.Manifest, ...distribution.ManifestServiceOption) (digest.Digest, error)); ok { + return rf(ctx, manifest, options...) + } + if rf, ok := ret.Get(0).(func(context.Context, distribution.Manifest, ...distribution.ManifestServiceOption) digest.Digest); ok { + r0 = rf(ctx, manifest, options...) + } else { + r0 = ret.Get(0).(digest.Digest) + } + + if rf, ok := ret.Get(1).(func(context.Context, distribution.Manifest, ...distribution.ManifestServiceOption) error); ok { + r1 = rf(ctx, manifest, options...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewManifestService creates a new instance of ManifestService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewManifestService(t interface { + mock.TestingT + Cleanup(func()) +}) *ManifestService { + mock := &ManifestService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/registry-scanner/pkg/registry/mocks/RegistryClient.go b/registry-scanner/pkg/registry/mocks/RegistryClient.go new file mode 100644 index 0000000..2943b04 --- /dev/null +++ b/registry-scanner/pkg/registry/mocks/RegistryClient.go @@ -0,0 +1,125 @@ +// Code generated by mockery v1.1.2. DO NOT EDIT. + +package mocks + +import ( + distribution "github.com/distribution/distribution/v3" + digest "github.com/opencontainers/go-digest" + + mock "github.com/stretchr/testify/mock" + + options "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/options" + + tag "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/tag" +) + +// RegistryClient is an autogenerated mock type for the RegistryClient type +type RegistryClient struct { + mock.Mock +} + +// ManifestForDigest provides a mock function with given fields: dgst +func (_m *RegistryClient) ManifestForDigest(dgst digest.Digest) (distribution.Manifest, error) { + ret := _m.Called(dgst) + + var r0 distribution.Manifest + if rf, ok := ret.Get(0).(func(digest.Digest) distribution.Manifest); ok { + r0 = rf(dgst) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(distribution.Manifest) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(digest.Digest) error); ok { + r1 = rf(dgst) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ManifestForTag provides a mock function with given fields: tagStr +func (_m *RegistryClient) ManifestForTag(tagStr string) (distribution.Manifest, error) { + ret := _m.Called(tagStr) + + var r0 distribution.Manifest + if rf, ok := ret.Get(0).(func(string) distribution.Manifest); ok { + r0 = rf(tagStr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(distribution.Manifest) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(tagStr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewRepository provides a mock function with given fields: nameInRepository +func (_m *RegistryClient) NewRepository(nameInRepository string) error { + ret := _m.Called(nameInRepository) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(nameInRepository) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TagMetadata provides a mock function with given fields: manifest, opts +func (_m *RegistryClient) TagMetadata(manifest distribution.Manifest, opts *options.ManifestOptions) (*tag.TagInfo, error) { + ret := _m.Called(manifest, opts) + + var r0 *tag.TagInfo + if rf, ok := ret.Get(0).(func(distribution.Manifest, *options.ManifestOptions) *tag.TagInfo); ok { + r0 = rf(manifest, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*tag.TagInfo) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(distribution.Manifest, *options.ManifestOptions) error); ok { + r1 = rf(manifest, opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Tags provides a mock function with given fields: +func (_m *RegistryClient) Tags() ([]string, error) { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/registry-scanner/pkg/registry/mocks/Repository.go b/registry-scanner/pkg/registry/mocks/Repository.go new file mode 100644 index 0000000..04d9b7c --- /dev/null +++ b/registry-scanner/pkg/registry/mocks/Repository.go @@ -0,0 +1,128 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + distribution "github.com/distribution/distribution/v3" + mock "github.com/stretchr/testify/mock" + + reference "github.com/distribution/distribution/v3/reference" +) + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +// Blobs provides a mock function with given fields: ctx +func (_m *Repository) Blobs(ctx context.Context) distribution.BlobStore { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Blobs") + } + + var r0 distribution.BlobStore + if rf, ok := ret.Get(0).(func(context.Context) distribution.BlobStore); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(distribution.BlobStore) + } + } + + return r0 +} + +// Manifests provides a mock function with given fields: ctx, options +func (_m *Repository) Manifests(ctx context.Context, options ...distribution.ManifestServiceOption) (distribution.ManifestService, error) { + _va := make([]interface{}, len(options)) + for _i := range options { + _va[_i] = options[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Manifests") + } + + var r0 distribution.ManifestService + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ...distribution.ManifestServiceOption) (distribution.ManifestService, error)); ok { + return rf(ctx, options...) + } + if rf, ok := ret.Get(0).(func(context.Context, ...distribution.ManifestServiceOption) distribution.ManifestService); ok { + r0 = rf(ctx, options...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(distribution.ManifestService) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ...distribution.ManifestServiceOption) error); ok { + r1 = rf(ctx, options...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Named provides a mock function with given fields: +func (_m *Repository) Named() reference.Named { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Named") + } + + var r0 reference.Named + if rf, ok := ret.Get(0).(func() reference.Named); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(reference.Named) + } + } + + return r0 +} + +// Tags provides a mock function with given fields: ctx +func (_m *Repository) Tags(ctx context.Context) distribution.TagService { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Tags") + } + + var r0 distribution.TagService + if rf, ok := ret.Get(0).(func(context.Context) distribution.TagService); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(distribution.TagService) + } + } + + return r0 +} + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/registry-scanner/pkg/registry/mocks/RoundTripper.go b/registry-scanner/pkg/registry/mocks/RoundTripper.go new file mode 100644 index 0000000..27e22c2 --- /dev/null +++ b/registry-scanner/pkg/registry/mocks/RoundTripper.go @@ -0,0 +1,58 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + http "net/http" + + mock "github.com/stretchr/testify/mock" +) + +// RoundTripper is an autogenerated mock type for the RoundTripper type +type RoundTripper struct { + mock.Mock +} + +// RoundTrip provides a mock function with given fields: _a0 +func (_m *RoundTripper) RoundTrip(_a0 *http.Request) (*http.Response, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for RoundTrip") + } + + var r0 *http.Response + var r1 error + if rf, ok := ret.Get(0).(func(*http.Request) (*http.Response, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(*http.Request) *http.Response); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*http.Response) + } + } + + if rf, ok := ret.Get(1).(func(*http.Request) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewRoundTripper creates a new instance of RoundTripper. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRoundTripper(t interface { + mock.TestingT + Cleanup(func()) +}) *RoundTripper { + mock := &RoundTripper{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/registry-scanner/pkg/registry/mocks/TagService.go b/registry-scanner/pkg/registry/mocks/TagService.go new file mode 100644 index 0000000..5037808 --- /dev/null +++ b/registry-scanner/pkg/registry/mocks/TagService.go @@ -0,0 +1,153 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + distribution "github.com/distribution/distribution/v3" + mock "github.com/stretchr/testify/mock" +) + +// TagService is an autogenerated mock type for the TagService type +type TagService struct { + mock.Mock +} + +// All provides a mock function with given fields: ctx +func (_m *TagService) All(ctx context.Context) ([]string, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for All") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]string, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []string); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Get provides a mock function with given fields: ctx, tag +func (_m *TagService) Get(ctx context.Context, tag string) (distribution.Descriptor, error) { + ret := _m.Called(ctx, tag) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 distribution.Descriptor + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (distribution.Descriptor, error)); ok { + return rf(ctx, tag) + } + if rf, ok := ret.Get(0).(func(context.Context, string) distribution.Descriptor); ok { + r0 = rf(ctx, tag) + } else { + r0 = ret.Get(0).(distribution.Descriptor) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, tag) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Lookup provides a mock function with given fields: ctx, digest +func (_m *TagService) Lookup(ctx context.Context, digest distribution.Descriptor) ([]string, error) { + ret := _m.Called(ctx, digest) + + if len(ret) == 0 { + panic("no return value specified for Lookup") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, distribution.Descriptor) ([]string, error)); ok { + return rf(ctx, digest) + } + if rf, ok := ret.Get(0).(func(context.Context, distribution.Descriptor) []string); ok { + r0 = rf(ctx, digest) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, distribution.Descriptor) error); ok { + r1 = rf(ctx, digest) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Tag provides a mock function with given fields: ctx, tag, desc +func (_m *TagService) Tag(ctx context.Context, tag string, desc distribution.Descriptor) error { + ret := _m.Called(ctx, tag, desc) + + if len(ret) == 0 { + panic("no return value specified for Tag") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, distribution.Descriptor) error); ok { + r0 = rf(ctx, tag, desc) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Untag provides a mock function with given fields: ctx, tag +func (_m *TagService) Untag(ctx context.Context, tag string) error { + ret := _m.Called(ctx, tag) + + if len(ret) == 0 { + panic("no return value specified for Untag") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, tag) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewTagService creates a new instance of TagService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTagService(t interface { + mock.TestingT + Cleanup(func()) +}) *TagService { + mock := &TagService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/registry-scanner/pkg/registry/registry.go b/registry-scanner/pkg/registry/registry.go new file mode 100644 index 0000000..5c84b8f --- /dev/null +++ b/registry-scanner/pkg/registry/registry.go @@ -0,0 +1,222 @@ +package registry + +// Package registry implements functions for retrieving data from container +// registries. +// +// TODO: Refactor this package and provide mocks for better testing. + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/distribution/distribution/v3" + + "golang.org/x/sync/semaphore" + + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/image" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/kube" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/log" + "github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/tag" +) + +const ( + MaxMetadataConcurrency = 20 +) + +// GetTags returns a list of available tags for the given image +func (endpoint *RegistryEndpoint) GetTags(img *image.ContainerImage, regClient RegistryClient, vc *image.VersionConstraint) (*tag.ImageTagList, error) { + var tagList *tag.ImageTagList = tag.NewImageTagList() + var err error + + logCtx := vc.Options.Logger() + + // Some registries have a default namespace that is used when the image name + // doesn't specify one. For example at Docker Hub, this is 'library'. + var nameInRegistry string + if len := len(strings.Split(img.ImageName, "/")); len == 1 && endpoint.DefaultNS != "" { + nameInRegistry = endpoint.DefaultNS + "/" + img.ImageName + logCtx.Debugf("Using canonical image name '%s' for image '%s'", nameInRegistry, img.ImageName) + } else { + nameInRegistry = img.ImageName + } + err = regClient.NewRepository(nameInRegistry) + if err != nil { + return nil, err + } + tTags, err := regClient.Tags() + if err != nil { + return nil, err + } + + tags := []string{} + + // For digest strategy, we do require a version constraint + if vc.Strategy.NeedsVersionConstraint() && vc.Constraint == "" { + return nil, fmt.Errorf("cannot use update strategy 'digest' for image '%s' without a version constraint", img.Original()) + } + + // Loop through tags, removing those we do not want. If update strategy is + // digest, all but the constraint tag are ignored. + if vc.MatchFunc != nil || len(vc.IgnoreList) > 0 || vc.Strategy.WantsOnlyConstraintTag() { + for _, t := range tTags { + if (vc.MatchFunc != nil && !vc.MatchFunc(t, vc.MatchArgs)) || vc.IsTagIgnored(t) || (vc.Strategy.WantsOnlyConstraintTag() && t != vc.Constraint) { + logCtx.Tracef("Removing tag %s because it either didn't match defined pattern or is ignored", t) + } else { + tags = append(tags, t) + } + } + } else { + tags = tTags + } + + // In some cases, we don't need to fetch the metadata to get the creation time + // stamp of from the image's meta data: + // + // - We use an update strategy other than latest or digest + // - The registry doesn't provide meta data and has tags sorted already + // + // We just create a dummy time stamp according to the registry's sort mode, if + // set. + if (vc.Strategy != image.StrategyNewestBuild && vc.Strategy != image.StrategyDigest) || endpoint.TagListSort.IsTimeSorted() { + for i, tagStr := range tags { + var ts int + if endpoint.TagListSort == TagListSortLatestFirst { + ts = len(tags) - i + } else if endpoint.TagListSort == TagListSortLatestLast { + ts = i + } + imgTag := tag.NewImageTag(tagStr, time.Unix(int64(ts), 0), "") + tagList.Add(imgTag) + } + return tagList, nil + } + + sem := semaphore.NewWeighted(int64(MaxMetadataConcurrency)) + tagListLock := &sync.RWMutex{} + + var wg sync.WaitGroup + wg.Add(len(tags)) + + // Fetch the manifest for the tag -- we need v1, because it contains history + // information that we require. + i := 0 + for _, tagStr := range tags { + i += 1 + // Look into the cache first and re-use any found item. If GetTag() returns + // an error, we treat it as a cache miss and just go ahead to invalidate + // the entry. + if vc.Strategy.IsCacheable() { + imgTag, err := endpoint.Cache.GetTag(nameInRegistry, tagStr) + if err != nil { + log.Warnf("invalid entry for %s:%s in cache, invalidating.", nameInRegistry, imgTag.TagName) + } else if imgTag != nil { + logCtx.Debugf("Cache hit for %s:%s", nameInRegistry, imgTag.TagName) + tagListLock.Lock() + tagList.Add(imgTag) + tagListLock.Unlock() + wg.Done() + continue + } + } + + logCtx.Tracef("Getting manifest for image %s:%s (operation %d/%d)", nameInRegistry, tagStr, i, len(tags)) + + lockErr := sem.Acquire(context.TODO(), 1) + if lockErr != nil { + log.Warnf("could not acquire semaphore: %v", lockErr) + wg.Done() + continue + } + logCtx.Tracef("acquired metadata semaphore") + + go func(tagStr string) { + defer func() { + sem.Release(1) + wg.Done() + log.Tracef("released semaphore and terminated waitgroup") + }() + + var ml distribution.Manifest + var err error + + // We first try to fetch a V2 manifest, and if that's not available we fall + // back to fetching V1 manifest. If that fails also, we just skip this tag. + if ml, err = regClient.ManifestForTag(tagStr); err != nil { + logCtx.Errorf("Error fetching metadata for %s:%s - neither V1 or V2 or OCI manifest returned by registry: %v", nameInRegistry, tagStr, err) + return + } + + // Parse required meta data from the manifest. The metadata contains all + // information needed to decide whether to consider this tag or not. + ti, err := regClient.TagMetadata(ml, vc.Options) + if err != nil { + logCtx.Errorf("error fetching metadata for %s:%s: %v", nameInRegistry, tagStr, err) + return + } + if ti == nil { + logCtx.Debugf("No metadata found for %s:%s", nameInRegistry, tagStr) + return + } + + logCtx.Tracef("Found date %s", ti.CreatedAt.String()) + var imgTag *tag.ImageTag + if vc.Strategy == image.StrategyDigest { + imgTag = tag.NewImageTag(tagStr, ti.CreatedAt, fmt.Sprintf("sha256:%x", ti.Digest)) + } else { + imgTag = tag.NewImageTag(tagStr, ti.CreatedAt, "") + } + tagListLock.Lock() + tagList.Add(imgTag) + tagListLock.Unlock() + endpoint.Cache.SetTag(nameInRegistry, imgTag) + }(tagStr) + } + + wg.Wait() + return tagList, err +} + +func (ep *RegistryEndpoint) expireCredentials() bool { + if ep.Credentials != "" && !ep.CredsUpdated.IsZero() && ep.CredsExpire > 0 && time.Since(ep.CredsUpdated) >= ep.CredsExpire { + ep.Username = "" + ep.Password = "" + return true + } + return false +} + +// Sets endpoint credentials for this registry from a reference to a K8s secret +func (ep *RegistryEndpoint) SetEndpointCredentials(kubeClient *kube.KubernetesClient) error { + if ep.expireCredentials() { + log.Debugf("expired credentials for registry %s (updated:%s, expiry:%0fs)", ep.RegistryAPI, ep.CredsUpdated, ep.CredsExpire.Seconds()) + } + if ep.Username == "" && ep.Password == "" && ep.Credentials != "" { + credSrc, err := image.ParseCredentialSource(ep.Credentials, false) + if err != nil { + return err + } + + // For fetching credentials, we must have working Kubernetes client. + if (credSrc.Type == image.CredentialSourcePullSecret || credSrc.Type == image.CredentialSourceSecret) && kubeClient == nil { + log.WithContext(). + AddField("registry", ep.RegistryAPI). + Warnf("cannot use K8s credentials without Kubernetes client") + return fmt.Errorf("could not fetch image tags") + } + + creds, err := credSrc.FetchCredentials(ep.RegistryAPI, kubeClient) + if err != nil { + return err + } + + ep.CredsUpdated = time.Now() + + ep.Username = creds.Username + ep.Password = creds.Password + } + + return nil +} diff --git a/registry-scanner/pkg/registry/registry_test.go b/registry-scanner/pkg/registry/registry_test.go new file mode 100644 index 0000000..ff525ce --- /dev/null +++ b/registry-scanner/pkg/registry/registry_test.go @@ -0,0 +1,157 @@ +package registry + +import ( + "os" + "testing" + "time" + + //nolint:staticcheck + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test relies on image package which is not available yet. Will uncomment as soon as it is available. +// func Test_GetTags(t *testing.T) { + +// t.Run("Check for correctly returned tags with semver sort", func(t *testing.T) { +// regClient := mocks.RegistryClient{} +// regClient.On("NewRepository", mock.Anything).Return(nil) +// regClient.On("Tags", mock.Anything).Return([]string{"1.2.0", "1.2.1", "1.2.2"}, nil) + +// ep, err := GetRegistryEndpoint("") +// require.NoError(t, err) + +// img := image.NewFromIdentifier("foo/bar:1.2.0") + +// tl, err := ep.GetTags(img, ®Client, &image.VersionConstraint{Strategy: image.StrategySemVer, Options: options.NewManifestOptions()}) +// require.NoError(t, err) +// assert.NotEmpty(t, tl) + +// tag, err := ep.Cache.GetTag("foo/bar", "1.2.1") +// require.NoError(t, err) +// assert.Nil(t, tag) +// }) + +// t.Run("Check for correctly returned tags with filter function applied", func(t *testing.T) { +// regClient := mocks.RegistryClient{} +// regClient.On("NewRepository", mock.Anything).Return(nil) +// regClient.On("Tags", mock.Anything).Return([]string{"1.2.0", "1.2.1", "1.2.2"}, nil) + +// ep, err := GetRegistryEndpoint("") +// require.NoError(t, err) + +// img := image.NewFromIdentifier("foo/bar:1.2.0") + +// tl, err := ep.GetTags(img, ®Client, &image.VersionConstraint{ +// Strategy: image.StrategySemVer, +// MatchFunc: image.MatchFuncNone, +// Options: options.NewManifestOptions()}) +// require.NoError(t, err) +// assert.Empty(t, tl.Tags()) + +// tag, err := ep.Cache.GetTag("foo/bar", "1.2.1") +// require.NoError(t, err) +// assert.Nil(t, tag) +// }) + +// t.Run("Check for correctly returned tags with name sort", func(t *testing.T) { + +// regClient := mocks.RegistryClient{} +// regClient.On("NewRepository", mock.Anything).Return(nil) +// regClient.On("Tags", mock.Anything).Return([]string{"1.2.0", "1.2.1", "1.2.2"}, nil) + +// ep, err := GetRegistryEndpoint("") +// require.NoError(t, err) + +// img := image.NewFromIdentifier("foo/bar:1.2.0") + +// tl, err := ep.GetTags(img, ®Client, &image.VersionConstraint{Strategy: image.StrategyAlphabetical, Options: options.NewManifestOptions()}) +// require.NoError(t, err) +// assert.NotEmpty(t, tl) + +// tag, err := ep.Cache.GetTag("foo/bar", "1.2.1") +// require.NoError(t, err) +// assert.Nil(t, tag) +// }) + +// t.Run("Check for correctly returned tags with latest sort", func(t *testing.T) { +// ts := "2006-01-02T15:04:05.999999999Z" +// meta1 := &schema1.SignedManifest{ //nolint:staticcheck +// Manifest: schema1.Manifest{ //nolint:staticcheck +// History: []schema1.History{ //nolint:staticcheck +// { +// V1Compatibility: `{"created":"` + ts + `"}`, +// }, +// }, +// }, +// } + +// regClient := mocks.RegistryClient{} +// regClient.On("NewRepository", mock.Anything).Return(nil) +// regClient.On("Tags", mock.Anything).Return([]string{"1.2.0", "1.2.1", "1.2.2"}, nil) +// regClient.On("ManifestForTag", mock.Anything, mock.Anything).Return(meta1, nil) +// regClient.On("TagMetadata", mock.Anything, mock.Anything).Return(&tag.TagInfo{}, nil) + +// ep, err := GetRegistryEndpoint("") +// require.NoError(t, err) +// ep.Cache.ClearCache() + +// img := image.NewFromIdentifier("foo/bar:1.2.0") +// tl, err := ep.GetTags(img, ®Client, &image.VersionConstraint{Strategy: image.StrategyNewestBuild, Options: options.NewManifestOptions()}) +// require.NoError(t, err) +// assert.NotEmpty(t, tl) + +// tag, err := ep.Cache.GetTag("foo/bar", "1.2.1") +// require.NoError(t, err) +// require.NotNil(t, tag) +// require.Equal(t, "1.2.1", tag.TagName) +// }) + +// } + +func Test_ExpireCredentials(t *testing.T) { + epYAML := ` +registries: +- name: GitHub Container Registry + api_url: https://ghcr.io + ping: no + prefix: ghcr.io + credentials: env:TEST_CREDS + credsexpire: 3s +` + t.Run("Expire credentials", func(t *testing.T) { + epl, err := ParseRegistryConfiguration(epYAML) + require.NoError(t, err) + require.Len(t, epl.Items, 1) + + // New registry configuration + err = AddRegistryEndpointFromConfig(epl.Items[0]) + require.NoError(t, err) + ep, err := GetRegistryEndpoint("ghcr.io") + require.NoError(t, err) + require.NotEqual(t, 0, ep.CredsExpire) + + // Initial creds + os.Setenv("TEST_CREDS", "foo:bar") + err = ep.SetEndpointCredentials(nil) + assert.NoError(t, err) + assert.Equal(t, "foo", ep.Username) + assert.Equal(t, "bar", ep.Password) + assert.False(t, ep.CredsUpdated.IsZero()) + + // Creds should still be cached + os.Setenv("TEST_CREDS", "bar:foo") + err = ep.SetEndpointCredentials(nil) + assert.NoError(t, err) + assert.Equal(t, "foo", ep.Username) + assert.Equal(t, "bar", ep.Password) + + // Pretend 5 minutes have passed - creds have expired and are re-read from env + ep.CredsUpdated = ep.CredsUpdated.Add(time.Minute * -5) + err = ep.SetEndpointCredentials(nil) + assert.NoError(t, err) + assert.Equal(t, "bar", ep.Username) + assert.Equal(t, "foo", ep.Password) + }) + +} diff --git a/registry-scanner/test/fake/kubernetes.go b/registry-scanner/test/fake/kubernetes.go new file mode 100644 index 0000000..cad8b50 --- /dev/null +++ b/registry-scanner/test/fake/kubernetes.go @@ -0,0 +1,16 @@ +package fake + +import ( + "k8s.io/apimachinery/pkg/runtime" + kubefake "k8s.io/client-go/kubernetes/fake" +) + +func NewFakeKubeClient() *kubefake.Clientset { + clientset := kubefake.NewSimpleClientset() + return clientset +} + +func NewFakeClientsetWithResources(objects ...runtime.Object) *kubefake.Clientset { + clientset := kubefake.NewSimpleClientset(objects...) + return clientset +} |
