From 5ff5e7a6ef824381548e33f074dd8cadcc91024e Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Fri, 7 May 2021 11:20:51 -0400 Subject: [PATCH] Support passing `shield` at stream contruction --- tractor/_streaming.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tractor/_streaming.py b/tractor/_streaming.py index 1c46801..10be2c6 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -37,10 +37,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): self, ctx: 'Context', # typing: ignore # noqa rx_chan: trio.abc.ReceiveChannel, + shield: bool = False, ) -> None: self._ctx = ctx self._rx_chan = rx_chan - self._shielded = False + self._shielded = shield # delegate directly to underlying mem channel def receive_nowait(self): @@ -112,6 +113,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): """Shield this stream's underlying channel such that a local consumer task can be cancelled (and possibly restarted) using ``trio.Cancelled``. + Note that here, "shielding" here guards against relaying + a ``'stop'`` message to the far end of the stream thus keeping + the stream machinery active and ready for further use, it does + not have anything to do with an internal ``trio.CancelScope``. + """ self._shielded = True yield self @@ -162,7 +168,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): await self._ctx.send_stop() # close the local mem chan - rx_chan.close() + await rx_chan.aclose() # TODO: but make it broadcasting to consumers # def clone(self): @@ -281,6 +287,7 @@ class Context: @asynccontextmanager async def open_stream( self, + shield: bool = False, ) -> MsgStream: # TODO @@ -299,7 +306,11 @@ class Context: self.cid ) - async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan: + async with MsgStream( + ctx=self, + rx_chan=recv_chan, + shield=shield, + ) as rchan: if self._portal: self._portal._streams.add(rchan) @@ -308,9 +319,11 @@ class Context: yield rchan finally: + # signal ``StopAsyncIteration`` on far end. await self.send_stop() + if self._portal: - self._portal._streams.add(rchan) + self._portal._streams.remove(rchan) async def started(self, value: Any) -> None: