diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 1e85d04..ba8052f 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -3,7 +3,7 @@ import itertools import trio import tractor from tractor import open_actor_cluster -from tractor.trionics import async_enter_all +from tractor.trionics import gather_contexts from conftest import tractor_test @@ -25,10 +25,10 @@ async def worker(ctx: tractor.Context) -> None: async def test_streaming_to_actor_cluster() -> None: async with ( open_actor_cluster(modules=[__name__]) as portals, - async_enter_all( + gather_contexts( mngrs=[p.open_context(worker) for p in portals.values()], ) as contexts, - async_enter_all( + gather_contexts( mngrs=[ctx[0].open_stream() for ctx in contexts], ) as streams, ): diff --git a/tractor/trionics/__init__.py b/tractor/trionics/__init__.py index f7e90c6..3d1b9bd 100644 --- a/tractor/trionics/__init__.py +++ b/tractor/trionics/__init__.py @@ -2,12 +2,12 @@ Sugary patterns for trio + tractor designs. ''' -from ._mngrs import async_enter_all +from ._mngrs import gather_contexts from ._broadcast import broadcast_receiver, BroadcastReceiver, Lagged __all__ = [ - 'async_enter_all', + 'gather_contexts', 'broadcast_receiver', 'BroadcastReceiver', 'Lagged', diff --git a/tractor/trionics/_mngrs.py b/tractor/trionics/_mngrs.py index ca75c55..d31f1e0 100644 --- a/tractor/trionics/_mngrs.py +++ b/tractor/trionics/_mngrs.py @@ -14,11 +14,15 @@ T = TypeVar("T") async def _enter_and_wait( + mngr: AsyncContextManager[T], unwrapped: dict[int, T], all_entered: trio.Event, + parent_exit: trio.Event, + ) -> None: - '''Open the async context manager deliver it's value + ''' + Open the async context manager deliver it's value to this task's spawner and sleep until cancelled. ''' @@ -28,16 +32,31 @@ async def _enter_and_wait( if all(unwrapped.values()): all_entered.set() - await trio.sleep_forever() + await parent_exit.wait() @acm -async def async_enter_all( +async def gather_contexts( + mngrs: Sequence[AsyncContextManager[T]], + ) -> AsyncGenerator[tuple[T, ...], None]: + ''' + Concurrently enter a sequence of async context managers, each in + a separate ``trio`` task and deliver the unwrapped values in the + same order once all managers have entered. On exit all contexts are + subsequently and concurrently exited. + + This function is somewhat similar to common usage of + ``contextlib.AsyncExitStack.enter_async_context()`` (in a loop) in + combo with ``asyncio.gather()`` except the managers are concurrently + entered and exited cancellation just works. + + ''' unwrapped = {}.fromkeys(id(mngr) for mngr in mngrs) all_entered = trio.Event() + parent_exit = trio.Event() async with trio.open_nursery() as n: for mngr in mngrs: @@ -46,6 +65,7 @@ async def async_enter_all( mngr, unwrapped, all_entered, + parent_exit, ) # deliver control once all managers have started up @@ -53,4 +73,6 @@ async def async_enter_all( yield tuple(unwrapped.values()) - n.cancel_scope.cancel() \ No newline at end of file + # we don't need a try/finally since cancellation will be triggered + # by the surrounding nursery on error. + parent_exit.set()