summaryrefslogtreecommitdiff
path: root/robusta_krr/utils
diff options
context:
space:
mode:
authorLeaveMyYard <zhukovpavel2001@gmail.com>2024-03-14 13:06:14 +0200
committerLeaveMyYard <zhukovpavel2001@gmail.com>2024-03-14 13:06:14 +0200
commit4c8e727205221f4d38a638b3ed51dea568b26309 (patch)
tree7cae5b7cc635d58a7db7489d670fb8ab0286a03c /robusta_krr/utils
parent50147bd63ea57246d6b2653a849fb16da26f6339 (diff)
Remove aiostream
Diffstat (limited to 'robusta_krr/utils')
-rw-r--r--robusta_krr/utils/async_gen_merge.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/robusta_krr/utils/async_gen_merge.py b/robusta_krr/utils/async_gen_merge.py
new file mode 100644
index 0000000..7152895
--- /dev/null
+++ b/robusta_krr/utils/async_gen_merge.py
@@ -0,0 +1,49 @@
+import asyncio
+import logging
+from typing import AsyncIterable, TypeVar
+
+
+logger = logging.getLogger("krr")
+
+
+# Define a type variable for the values yielded by the async generators
+T = TypeVar("T")
+
+
+def async_gen_merge(*aiters: AsyncIterable[T]) -> AsyncIterable[T]:
+ queue = asyncio.Queue(1)
+ run_count = len(aiters)
+ cancelling = False
+
+ async def drain(aiter):
+ nonlocal run_count
+ try:
+ async for item in aiter:
+ await queue.put((False, item))
+ except Exception as e:
+ if not cancelling:
+ await queue.put((True, e))
+ else:
+ raise
+ finally:
+ run_count -= 1
+
+ async def merged():
+ try:
+ while run_count:
+ raised, next_item = await queue.get()
+ if raised:
+ cancel_tasks()
+ raise next_item
+ yield next_item
+ finally:
+ cancel_tasks()
+
+ def cancel_tasks():
+ nonlocal cancelling
+ cancelling = True
+ for t in tasks:
+ t.cancel()
+
+ tasks = [asyncio.create_task(drain(aiter)) for aiter in aiters]
+ return merged()