diff options
| author | LeaveMyYard <zhukovpavel2001@gmail.com> | 2024-03-14 13:06:14 +0200 |
|---|---|---|
| committer | LeaveMyYard <zhukovpavel2001@gmail.com> | 2024-03-14 13:06:14 +0200 |
| commit | 4c8e727205221f4d38a638b3ed51dea568b26309 (patch) | |
| tree | 7cae5b7cc635d58a7db7489d670fb8ab0286a03c /robusta_krr/utils | |
| parent | 50147bd63ea57246d6b2653a849fb16da26f6339 (diff) | |
Remove aiostream
Diffstat (limited to 'robusta_krr/utils')
| -rw-r--r-- | robusta_krr/utils/async_gen_merge.py | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/robusta_krr/utils/async_gen_merge.py b/robusta_krr/utils/async_gen_merge.py new file mode 100644 index 0000000..7152895 --- /dev/null +++ b/robusta_krr/utils/async_gen_merge.py @@ -0,0 +1,49 @@ +import asyncio +import logging +from typing import AsyncIterable, TypeVar + + +logger = logging.getLogger("krr") + + +# Define a type variable for the values yielded by the async generators +T = TypeVar("T") + + +def async_gen_merge(*aiters: AsyncIterable[T]) -> AsyncIterable[T]: + queue = asyncio.Queue(1) + run_count = len(aiters) + cancelling = False + + async def drain(aiter): + nonlocal run_count + try: + async for item in aiter: + await queue.put((False, item)) + except Exception as e: + if not cancelling: + await queue.put((True, e)) + else: + raise + finally: + run_count -= 1 + + async def merged(): + try: + while run_count: + raised, next_item = await queue.get() + if raised: + cancel_tasks() + raise next_item + yield next_item + finally: + cancel_tasks() + + def cancel_tasks(): + nonlocal cancelling + cancelling = True + for t in tasks: + t.cancel() + + tasks = [asyncio.create_task(drain(aiter)) for aiter in aiters] + return merged() |
