summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- robusta_krr/core/integrations/kubernetes/__init__.py | 97
-rw-r--r-- robusta_krr/core/runner.py | 8
-rw-r--r-- robusta_krr/utils/async_gen_merge.py | 39
-rw-r--r-- tests/conftest.py | 13
4 files changed, 51 insertions, 106 deletions
diff --git a/robusta_krr/core/integrations/kubernetes/__init__.py b/robusta_krr/core/integrations/kubernetes/__init__.py
index 335b47a..a772a5c 100644
--- a/robusta_krr/core/integrations/kubernetes/__init__.py
+++ b/robusta_krr/core/integrations/kubernetes/__init__.py
@@ -2,7 +2,7 @@ import asyncio
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, Iterable, Optional, Union
+from typing import Any, Awaitable, Callable, Iterable, Optional, Union
from kubernetes import client, config # type: ignore
from kubernetes.client import ApiException
@@ -20,7 +20,6 @@ from kubernetes.client.models import (
from robusta_krr.core.models.config import settings
from robusta_krr.core.models.objects import HPAData, K8sObjectData, KindLiteral, PodData
from robusta_krr.core.models.result import ResourceAllocations
-from robusta_krr.utils.async_gen_merge import async_gen_merge
from robusta_krr.utils.object_like_dict import ObjectLikeDict
from . import config_patch as _
@@ -49,7 +48,7 @@ class ClusterLoader:
self.__jobs_for_cronjobs: dict[str, list[V1Job]] = {}
self.__jobs_loading_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
- async def list_scannable_objects(self) -> AsyncGenerator[K8sObjectData, None]:
+ async def list_scannable_objects(self) -> list[K8sObjectData]:
"""List all scannable objects.
Returns:
@@ -61,10 +60,7 @@ class ClusterLoader:
logger.debug(f"Resources: {settings.resources}")
self.__hpa_list = await self._try_list_hpa()
-
- # https://stackoverflow.com/questions/55299564/join-multiple-async-generators-in-python
- # This will merge all the streams from all the cluster loaders into a single stream
- async for object in async_gen_merge(
+ workload_object_lists = await asyncio.gather(
self._list_deployments(),
self._list_rollouts(),
self._list_deploymentconfig(),
@@ -72,11 +68,15 @@ class ClusterLoader:
self._list_all_daemon_set(),
self._list_all_jobs(),
self._list_all_cronjobs(),
- ):
+ )
+
+ return [
+ object
+ for workload_objects in workload_object_lists
+ for object in workload_objects
# NOTE: By default we will filter out kube-system namespace
- if settings.namespaces == "*" and object.namespace == "kube-system":
- continue
- yield object
+ if not (settings.namespaces == "*" and object.namespace == "kube-system")
+ ]
async def _list_jobs_for_cronjobs(self, namespace: str) -> list[V1Job]:
if namespace not in self.__jobs_for_cronjobs:
@@ -185,12 +185,12 @@ class ClusterLoader:
kind: KindLiteral,
all_namespaces_request: Callable,
namespaced_request: Callable
- ) -> AsyncIterable[Any]:
+ ) -> list[Any]:
logger.debug(f"Listing {kind}s in {self.cluster}")
loop = asyncio.get_running_loop()
if settings.namespaces == "*":
- tasks = [
+ requests = [
loop.run_in_executor(
self.executor,
lambda: all_namespaces_request(
@@ -200,7 +200,7 @@ class ClusterLoader:
)
]
else:
- tasks = [
+ requests = [
loop.run_in_executor(
self.executor,
lambda ns=namespace: namespaced_request(
@@ -212,14 +212,14 @@ class ClusterLoader:
for namespace in settings.namespaces
]
- total_items = 0
- for task in asyncio.as_completed(tasks):
- ret_single = await task
- total_items += len(ret_single.items)
- for item in ret_single.items:
- yield item
+ result = [
+ item
+ for request_result in await asyncio.gather(*requests)
+ for item in request_result.items
+ ]
- logger.debug(f"Found {total_items} {kind} in {self.cluster}")
+ logger.debug(f"Found {len(result)} {kind} in {self.cluster}")
+ return result
async def _list_scannable_objects(
self,
@@ -228,16 +228,17 @@ class ClusterLoader:
namespaced_request: Callable,
extract_containers: Callable[[Any], Union[Iterable[V1Container], Awaitable[Iterable[V1Container]]]],
filter_workflows: Optional[Callable[[Any], bool]] = None,
- ) -> AsyncIterable[K8sObjectData]:
+ ) -> list[K8sObjectData]:
if not self._should_list_resource(kind):
logger.debug(f"Skipping {kind}s in {self.cluster}")
return
if not self.__kind_available[kind]:
return
-
+
+ result = []
try:
- async for item in self._list_namespaced_or_global_objects(kind, all_namespaces_request, namespaced_request):
+ for item in await self._list_namespaced_or_global_objects(kind, all_namespaces_request, namespaced_request):
if filter_workflows is not None and not filter_workflows(item):
continue
@@ -245,8 +246,7 @@ class ClusterLoader:
if asyncio.iscoroutine(containers):
containers = await containers
- for container in containers:
- yield self.__build_scannable_object(item, container, kind)
+ result.extend(self.__build_scannable_object(item, container, kind) for container in containers)
except ApiException as e:
if kind in ("Rollout", "DeploymentConfig") and e.status in [400, 401, 403, 404]:
if self.__kind_available[kind]:
@@ -256,7 +256,9 @@ class ClusterLoader:
logger.exception(f"Error {e.status} listing {kind} in cluster {self.cluster}: {e.reason}")
logger.error("Will skip this object type and continue.")
- def _list_deployments(self) -> AsyncIterable[K8sObjectData]:
+ return result
+
+ def _list_deployments(self) -> list[K8sObjectData]:
return self._list_scannable_objects(
kind="Deployment",
all_namespaces_request=self.apps.list_deployment_for_all_namespaces,
@@ -264,7 +266,7 @@ class ClusterLoader:
extract_containers=lambda item: item.spec.template.spec.containers,
)
- def _list_rollouts(self) -> AsyncIterable[K8sObjectData]:
+ def _list_rollouts(self) -> list[K8sObjectData]:
async def _extract_containers(item: Any) -> list[V1Container]:
if item.spec.template is not None:
return item.spec.template.spec.containers
@@ -311,7 +313,7 @@ class ClusterLoader:
extract_containers=_extract_containers,
)
- def _list_deploymentconfig(self) -> AsyncIterable[K8sObjectData]:
+ def _list_deploymentconfig(self) -> list[K8sObjectData]:
# NOTE: Using custom objects API returns dicts, but all other APIs return objects
# We need to handle this difference using a small wrapper
return self._list_scannable_objects(
@@ -335,7 +337,7 @@ class ClusterLoader:
extract_containers=lambda item: item.spec.template.spec.containers,
)
- def _list_all_statefulsets(self) -> AsyncIterable[K8sObjectData]:
+ def _list_all_statefulsets(self) -> list[K8sObjectData]:
return self._list_scannable_objects(
kind="StatefulSet",
all_namespaces_request=self.apps.list_stateful_set_for_all_namespaces,
@@ -343,7 +345,7 @@ class ClusterLoader:
extract_containers=lambda item: item.spec.template.spec.containers,
)
- def _list_all_daemon_set(self) -> AsyncIterable[K8sObjectData]:
+ def _list_all_daemon_set(self) -> list[K8sObjectData]:
return self._list_scannable_objects(
kind="DaemonSet",
all_namespaces_request=self.apps.list_daemon_set_for_all_namespaces,
@@ -351,7 +353,7 @@ class ClusterLoader:
extract_containers=lambda item: item.spec.template.spec.containers,
)
- def _list_all_jobs(self) -> AsyncIterable[K8sObjectData]:
+ def _list_all_jobs(self) -> list[K8sObjectData]:
return self._list_scannable_objects(
kind="Job",
all_namespaces_request=self.batch.list_job_for_all_namespaces,
@@ -363,7 +365,7 @@ class ClusterLoader:
),
)
- def _list_all_cronjobs(self) -> AsyncIterable[K8sObjectData]:
+ def _list_all_cronjobs(self) -> list[K8sObjectData]:
return self._list_scannable_objects(
kind="CronJob",
all_namespaces_request=self.batch.list_cron_job_for_all_namespaces,
@@ -398,14 +400,10 @@ class ClusterLoader:
}
async def __list_hpa_v2(self) -> dict[HPAKey, HPAData]:
- loop = asyncio.get_running_loop()
- res = await loop.run_in_executor(
- self.executor,
- lambda: self._list_namespaced_or_global_objects(
- kind="HPA-v2",
- all_namespaces_request=self.autoscaling_v2.list_horizontal_pod_autoscaler_for_all_namespaces,
- namespaced_request=self.autoscaling_v2.list_namespaced_horizontal_pod_autoscaler,
- ),
+ res = await self._list_namespaced_or_global_objects(
+ kind="HPA-v2",
+ all_namespaces_request=self.autoscaling_v2.list_horizontal_pod_autoscaler_for_all_namespaces,
+ namespaced_request=self.autoscaling_v2.list_namespaced_horizontal_pod_autoscaler,
)
def __get_metric(hpa: V2HorizontalPodAutoscaler, metric_name: str) -> Optional[float]:
return next(
@@ -429,7 +427,7 @@ class ClusterLoader:
target_cpu_utilization_percentage=__get_metric(hpa, "cpu"),
target_memory_utilization_percentage=__get_metric(hpa, "memory"),
)
- async for hpa in res
+ for hpa in res
}
# TODO: What should we do in case of other metrics bound to the HPA?
@@ -514,7 +512,7 @@ class KubernetesLoader:
logger.error(f"Could not load cluster {cluster} and will skip it: {e}")
return None
- async def list_scannable_objects(self, clusters: Optional[list[str]]) -> AsyncIterable[K8sObjectData]:
+ async def list_scannable_objects(self, clusters: Optional[list[str]]) -> list[K8sObjectData]:
"""List all scannable objects.
Yields:
@@ -529,13 +527,12 @@ class KubernetesLoader:
if self.cluster_loaders == {}:
logger.error("Could not load any cluster.")
return
-
- # https://stackoverflow.com/questions/55299564/join-multiple-async-generators-in-python
- # This will merge all the streams from all the cluster loaders into a single stream
- async for object in async_gen_merge(
- *[cluster_loader.list_scannable_objects() for cluster_loader in self.cluster_loaders.values()]
- ):
- yield object
+
+ return [
+ object
+ for cluster_loader in self.cluster_loaders.values()
+ for object in await cluster_loader.list_scannable_objects()
+ ]
async def load_pods(self, object: K8sObjectData) -> list[PodData]:
try:
diff --git a/robusta_krr/core/runner.py b/robusta_krr/core/runner.py
index 546dd01..8e08521 100644
--- a/robusta_krr/core/runner.py
+++ b/robusta_krr/core/runner.py
@@ -275,12 +275,8 @@ class Runner:
await asyncio.gather(*[self._check_data_availability(cluster) for cluster in clusters])
with ProgressBar(title="Calculating Recommendation") as self.__progressbar:
- scans_tasks = [
- asyncio.create_task(self._gather_object_allocations(k8s_object))
- async for k8s_object in self._k8s_loader.list_scannable_objects(clusters)
- ]
-
- scans = await asyncio.gather(*scans_tasks)
+ workloads = await self._k8s_loader.list_scannable_objects(clusters)
+ scans = await asyncio.gather(*[self._gather_object_allocations(k8s_object) for k8s_object in workloads])
successful_scans = [scan for scan in scans if scan is not None]
diff --git a/robusta_krr/utils/async_gen_merge.py b/robusta_krr/utils/async_gen_merge.py
deleted file mode 100644
index 35c2c86..0000000
--- a/robusta_krr/utils/async_gen_merge.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import asyncio
-import logging
-from typing import AsyncIterable, TypeVar
-
-
-logger = logging.getLogger("krr")
-
-
-# Define a type variable for the values yielded by the async generators
-T = TypeVar("T")
-
-
-def async_gen_merge(*aiters: AsyncIterable[T]) -> AsyncIterable[T]:
- queue = asyncio.Queue()
- iters_remaining = set(aiters)
-
- async def drain(aiter):
- try:
- async for item in aiter:
- await queue.put(item)
- except Exception:
- logger.exception(f"Error in async generator {aiter}")
- finally:
- iters_remaining.discard(aiter)
- await queue.put(None)
-
- async def merged():
- while iters_remaining or not queue.empty():
- item = await queue.get()
-
- if item is None:
- continue
-
- yield item
-
- for aiter in aiters:
- asyncio.create_task(drain(aiter))
-
- return merged()
diff --git a/tests/conftest.py b/tests/conftest.py
index 61c389d..b1d8d22 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,6 @@
import random
from datetime import datetime, timedelta
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, patch
import numpy as np
import pytest
@@ -26,15 +26,6 @@ TEST_OBJECT = K8sObjectData(
)
-class AsyncIter:
- def __init__(self, items):
- self.items = items
-
- async def __aiter__(self):
- for item in self.items:
- yield item
-
-
@pytest.fixture(autouse=True, scope="session")
def mock_list_clusters():
with patch(
@@ -48,7 +39,7 @@ def mock_list_clusters():
def mock_list_scannable_objects():
with patch(
"robusta_krr.core.integrations.kubernetes.KubernetesLoader.list_scannable_objects",
- new=MagicMock(return_value=AsyncIter([TEST_OBJECT])),
+ new=AsyncMock(return_value=[TEST_OBJECT]),
):
yield