summaryrefslogtreecommitdiff
path: root/robusta_krr/utils/batched.py
blob: c673aaae21933a579bf2809f4292f690f55478c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import itertools
from typing import Iterable, TypeVar

_T = TypeVar("_T")


def batched(iterable: Iterable[_T], n: int) -> Iterable[list[_T]]:
    """Batch data into lists of length n. The last batch may be shorter.

    Similar to ``itertools.batched`` (Python 3.12+), except that each batch is
    a ``list`` rather than a ``tuple``. Items are pulled from the source lazily,
    one batch at a time, as the result is iterated.

        batched('ABCDEFG', 3) --> ['A', 'B', 'C'] ['D', 'E', 'F'] ['G']

    Args:
        iterable: Source of items to group.
        n: Maximum number of items per batch; must be at least 1.

    Returns:
        An iterator yielding lists of up to ``n`` items each. An empty batch
        is never yielded, so an empty input produces no batches at all.

    Raises:
        ValueError: If ``n`` is less than 1. Raised immediately when
            ``batched`` is called, not deferred until the first batch is
            requested.
    """
    # Validate eagerly: if this function itself contained ``yield``, the check
    # would not run until the caller first iterated the returned generator.
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)

    def _generate() -> Iterable[list[_T]]:
        # islice takes at most n items; an empty list means the source is exhausted.
        while batch := list(itertools.islice(it, n)):
            yield batch

    return _generate()