summaryrefslogtreecommitdiff
path: root/robusta_krr/utils/async_gen_merge.py
blob: 35c2c866c320c2f7bf1896ddd3ae992f52e49f02 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import asyncio
import logging
from typing import AsyncIterable, TypeVar


logger: logging.Logger = logging.getLogger("krr")


# Generic item type: every merged source and the merged result share it.
T = TypeVar("T")


def async_gen_merge(*aiters: AsyncIterable[T]) -> AsyncIterable[T]:
    """Merge several async iterables into one, yielding items as they arrive.

    Each source is drained concurrently by its own task. The tasks are created
    immediately, so this function must be called while an event loop is
    running. Items are yielded in arrival order; no ordering between sources
    is guaranteed.

    An exception raised by a source is logged and that source is treated as
    exhausted; the other sources keep being merged.

    :param aiters: the async iterables to merge.
    :return: an async iterator over every item from every source. Closing it
        early (e.g. via ``aclose()``) cancels sources still being drained.
    """
    queue: asyncio.Queue = asyncio.Queue()
    # Private end-of-source marker. Unlike ``None`` it can never collide with
    # a value a source legitimately yields.
    done = object()
    # The event loop only keeps weak references to tasks, so hold strong ones
    # here until each task finishes. Otherwise a drain task could be garbage
    # collected mid-flight and the consumer would wait forever for its marker.
    tasks: set = set()

    async def drain(source) -> None:
        try:
            async for item in source:
                await queue.put(item)
        except Exception:
            logger.exception("Error in async generator %s", source)
        finally:
            # Exactly one marker per source: finished, failed or cancelled.
            await queue.put(done)

    async def merged():
        # Queue is FIFO, so once every source's marker has been seen, every
        # item those sources put before it has already been yielded.
        remaining = len(aiters)
        try:
            while remaining:
                item = await queue.get()
                if item is done:
                    remaining -= 1
                    continue
                yield item
        finally:
            # Consumer finished or stopped early: stop any still-running drains
            # (no-op for tasks that already completed).
            for task in tuple(tasks):
                task.cancel()

    for source in aiters:
        task = asyncio.create_task(drain(source))
        tasks.add(task)
        task.add_done_callback(tasks.discard)

    return merged()