Support passing `shield` at stream contruction

transport_hardening
Tyler Goodlet 2021-05-07 11:20:51 -04:00
parent bc689427ef
commit 66d18be2ec
1 changed files with 17 additions and 4 deletions

View File

@ -37,10 +37,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
self, self,
ctx: 'Context', # typing: ignore # noqa ctx: 'Context', # typing: ignore # noqa
rx_chan: trio.abc.ReceiveChannel, rx_chan: trio.abc.ReceiveChannel,
shield: bool = False,
) -> None: ) -> None:
self._ctx = ctx self._ctx = ctx
self._rx_chan = rx_chan self._rx_chan = rx_chan
self._shielded = False self._shielded = shield
# delegate directly to underlying mem channel # delegate directly to underlying mem channel
def receive_nowait(self): def receive_nowait(self):
@ -112,6 +113,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
"""Shield this stream's underlying channel such that a local consumer task """Shield this stream's underlying channel such that a local consumer task
can be cancelled (and possibly restarted) using ``trio.Cancelled``. can be cancelled (and possibly restarted) using ``trio.Cancelled``.
Note that here, "shielding" here guards against relaying
a ``'stop'`` message to the far end of the stream thus keeping
the stream machinery active and ready for further use, it does
not have anything to do with an internal ``trio.CancelScope``.
""" """
self._shielded = True self._shielded = True
yield self yield self
@ -162,7 +168,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
await self._ctx.send_stop() await self._ctx.send_stop()
# close the local mem chan # close the local mem chan
rx_chan.close() await rx_chan.aclose()
# TODO: but make it broadcasting to consumers # TODO: but make it broadcasting to consumers
# def clone(self): # def clone(self):
@ -281,6 +287,7 @@ class Context:
@asynccontextmanager @asynccontextmanager
async def open_stream( async def open_stream(
self, self,
shield: bool = False,
) -> MsgStream: ) -> MsgStream:
# TODO # TODO
@ -299,7 +306,11 @@ class Context:
self.cid self.cid
) )
async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan: async with MsgStream(
ctx=self,
rx_chan=recv_chan,
shield=shield,
) as rchan:
if self._portal: if self._portal:
self._portal._streams.add(rchan) self._portal._streams.add(rchan)
@ -308,9 +319,11 @@ class Context:
yield rchan yield rchan
finally: finally:
# signal ``StopAsyncIteration`` on far end.
await self.send_stop() await self.send_stop()
if self._portal: if self._portal:
self._portal._streams.add(rchan) self._portal._streams.remove(rchan)
async def started(self, value: Any) -> None: async def started(self, value: Any) -> None: