From 15ead6b561e0cce87d5bcbaecac308bf8f7f49c2 Mon Sep 17 00:00:00 2001
From: Tyler Goodlet
Date: Wed, 16 Dec 2020 21:42:28 -0500
Subject: [PATCH] Add a way to shield a stream's underlying channel

Add a ``tractor._portal.StreamReceiveChannel.shield_channel()`` context
manager which prevents an IPC stream's underlying channel from being
closed, for the purposes of task re-spawning. Sometimes you might want
to cancel a task consuming a stream but not tear down the IPC between
actors (the default behaviour). A common use case might be one where
the task's "setup" work needs to be redone but you want to keep the
established portal / channel intact despite the task restart.

Includes a test.
---
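
A minimal sketch of the usage pattern this enables, distilled from the
test added below (the ``consume()`` task and its nursery wiring here are
illustrative only, not part of the API)::

    import trio

    async def consume(stream, task_status=trio.TASK_STATUS_IGNORED):
        with trio.CancelScope() as cs:
            task_status.started(cs)
            # while the shield is held, any ``stream.aclose()`` attempt
            # just logs a warning and returns, so the IPC channel (and
            # the remote streaming task) survive this task's cancellation
            with stream.shield_channel():
                async for value in stream:
                    ...  # process values; safe to cancel this task

    # from the parent: cancel the consumer then respawn it against the
    # *same* stream, without re-establishing the portal / channel:
    #
    #     cs.cancel()
    #     cs = await nursery.start(consume, stream)
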
 tests/test_streaming.py | 73 +++++++++++++++++++++++++++++++++++++++++
 tractor/_portal.py      | 23 +++++++
 2 files changed, 96 insertions(+)

diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 919b278..c7ea4e0 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -7,6 +7,7 @@ import platform
 
 import trio
 import tractor
+from tractor.testing import tractor_test
 import pytest
 
 
@@ -53,6 +54,7 @@ async def stream_from_single_subactor(stream_func_name):
     """Verify we can spawn a daemon actor and retrieve streamed data.
     """
     async with tractor.find_actor('streamerd') as portals:
+
         if not portals:
             # only one per host address, spawns an actor if None
             async with tractor.open_nursery() as nursery:
@@ -73,8 +75,10 @@ async def stream_from_single_subactor(stream_func_name):
             # it'd sure be nice to have an asyncitertools here...
             iseq = iter(seq)
             ival = next(iseq)
+
             async for val in stream:
                 assert val == ival
+
                 try:
                     ival = next(iseq)
                 except StopIteration:
@@ -83,6 +87,7 @@ async def stream_from_single_subactor(stream_func_name):
             await stream.aclose()
 
             await trio.sleep(0.3)
+
             try:
                 await stream.__anext__()
             except StopAsyncIteration:
@@ -109,8 +114,11 @@ def test_stream_from_single_subactor(arb_addr, start_method, stream_func):
 
 # this is the first 2 actors, streamer_1 and streamer_2
 async def stream_data(seed):
+
     for i in range(seed):
+
         yield i
+
         # trigger scheduler to simulate practical usage
         await trio.sleep(0)
 
@@ -246,3 +254,68 @@ def test_not_fast_enough_quad(
     else:
         # should be cancelled mid-streaming
         assert results is None
+
+
+@tractor_test
+async def test_respawn_consumer_task(
+    arb_addr,
+    spawn_backend,
+    loglevel,
+):
+    """Verify that ``._portal.StreamReceiveChannel.shield_channel()``
+    successfully protects the underlying IPC channel from being closed
+    when cancelling and respawning a consumer task.
+
+    This also serves to verify that all values from the stream can be
+    received despite the respawns.
+
+    """
+    stream = None
+
+    async with tractor.open_nursery() as n:
+
+        stream = await (await n.run_in_actor(
+            'streamer',
+            stream_data,
+            seed=11,
+        )).result()
+
+        expect = set(range(11))
+        received = []
+
+        # this is the re-spawn task routine
+        async def consume(task_status=trio.TASK_STATUS_IGNORED):
+            print('starting consume task..')
+            nonlocal stream
+
+            with trio.CancelScope() as cs:
+                task_status.started(cs)
+
+                # shield stream's underlying channel from cancellation
+                with stream.shield_channel():
+
+                    async for v in stream:
+                        print(f'from stream: {v}')
+                        expect.remove(v)
+                        received.append(v)
+
+            print('exited consume')
+
+        async with trio.open_nursery() as ln:
+            cs = await ln.start(consume)
+
+            while True:
+
+                await trio.sleep(0.1)
+
+                if received and received[-1] % 2 == 0:
+
+                    print('cancelling consume task..')
+                    cs.cancel()
+
+                    # respawn
+                    cs = await ln.start(consume)
+
+                if not expect:
+                    print("all values streamed, BREAKING")
+                    break
diff --git a/tractor/_portal.py b/tractor/_portal.py
index af3c1b5..50036a8 100644
--- a/tractor/_portal.py
+++ b/tractor/_portal.py
@@ -7,6 +7,7 @@ import typing
 from typing import Tuple, Any, Dict, Optional, Set
 from functools import partial
 from dataclasses import dataclass
+from contextlib import contextmanager
 
 import trio
 from async_generator import asynccontextmanager
@@ -59,6 +60,7 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
         self._cid = cid
         self._rx_chan = rx_chan
         self._portal = portal
+        self._shielded = False
 
     # delegate directly to underlying mem channel
     def receive_nowait(self):
@@ -83,6 +85,21 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
                 "Received internal error at portal?")
             raise unpack_error(msg, self._portal.channel)
 
+    @contextmanager
+    def shield_channel(
+        self
+    ) -> typing.Iterator['StreamReceiveChannel']:
+        """Shield this stream's underlying channel such that a local consumer
+        task can be cancelled (and possibly restarted) via ``trio.Cancelled``.
+
+        """
+        self._shielded = True
+        try:
+            yield self
+        finally:
+            # always drop the shield, even if the consumer was cancelled
+            self._shielded = False
+
     async def aclose(self):
         """Cancel associated remote actor task and local memory channel
         on close.
@@ -90,12 +107,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
         if self._rx_chan._closed:
             log.warning(f"{self} is already closed")
             return
+
+        if self._shielded:
+            log.warning(f"{self} is shielded, portal channel being kept alive")
+            return
+
         cid = self._cid
         with trio.move_on_after(0.5) as cs:
             cs.shield = True
             log.warning(
                 f"Cancelling stream {cid} to "
                 f"{self._portal.channel.uid}")
+
             # NOTE: we're telling the far end actor to cancel a task
             # corresponding to *this actor*. The far end local channel
             # instance is passed to `Actor._cancel_task()` implicitly.
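
For reference, the shielding semantics added to ``aclose()`` above reduce
to the following (a sketch only, assuming ``stream`` is an open
``StreamReceiveChannel``; this snippet is not itself part of the patch)::

    async def demo(stream):
        with stream.shield_channel():
            await stream.aclose()  # no-op while shielded: just logs a
                                   # warning, the channel stays alive
        await stream.aclose()      # shield dropped: actually cancels
                                   # the remote task and closes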