summaryrefslogtreecommitdiff
path: root/robusta_krr/utils/async_gen_merge.py
diff options
context:
space:
mode:
Diffstat (limited to 'robusta_krr/utils/async_gen_merge.py')
-rw-r--r--robusta_krr/utils/async_gen_merge.py42
1 files changed, 16 insertions, 26 deletions
diff --git a/robusta_krr/utils/async_gen_merge.py b/robusta_krr/utils/async_gen_merge.py
index 7152895..35c2c86 100644
--- a/robusta_krr/utils/async_gen_merge.py
+++ b/robusta_krr/utils/async_gen_merge.py
@@ -11,39 +11,29 @@ T = TypeVar("T")
def async_gen_merge(*aiters: AsyncIterable[T]) -> AsyncIterable[T]:
- queue = asyncio.Queue(1)
- run_count = len(aiters)
- cancelling = False
+ queue = asyncio.Queue()
+ iters_remaining = set(aiters)
async def drain(aiter):
- nonlocal run_count
try:
async for item in aiter:
- await queue.put((False, item))
- except Exception as e:
- if not cancelling:
- await queue.put((True, e))
- else:
- raise
+ await queue.put(item)
+ except Exception:
+ logger.exception(f"Error in async generator {aiter}")
finally:
- run_count -= 1
+ iters_remaining.discard(aiter)
+ await queue.put(None)
async def merged():
- try:
- while run_count:
- raised, next_item = await queue.get()
- if raised:
- cancel_tasks()
- raise next_item
- yield next_item
- finally:
- cancel_tasks()
+ while iters_remaining or not queue.empty():
+ item = await queue.get()
+
+ if item is None:
+ continue
+
+ yield item
- def cancel_tasks():
- nonlocal cancelling
- cancelling = True
- for t in tasks:
- t.cancel()
+ for aiter in aiters:
+ asyncio.create_task(drain(aiter))
- tasks = [asyncio.create_task(drain(aiter)) for aiter in aiters]
return merged()