Support passing `shield` at stream contruction

wip_fix_asyncio_gen_streaming
Tyler Goodlet 2021-05-07 11:20:51 -04:00
parent b5116c5a51
commit e46ef8ae3f
1 changed files with 17 additions and 4 deletions

View File

@ -37,10 +37,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
self,
ctx: 'Context', # typing: ignore # noqa
rx_chan: trio.abc.ReceiveChannel,
shield: bool = False,
) -> None:
self._ctx = ctx
self._rx_chan = rx_chan
self._shielded = False
self._shielded = shield
# delegate directly to underlying mem channel
def receive_nowait(self):
@ -112,6 +113,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
"""Shield this stream's underlying channel such that a local consumer task
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
Note that here, "shielding" here guards against relaying
a ``'stop'`` message to the far end of the stream thus keeping
the stream machinery active and ready for further use, it does
not have anything to do with an internal ``trio.CancelScope``.
"""
self._shielded = True
yield self
@ -162,7 +168,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
await self._ctx.send_stop()
# close the local mem chan
rx_chan.close()
await rx_chan.aclose()
# TODO: but make it broadcasting to consumers
# def clone(self):
@ -281,6 +287,7 @@ class Context:
@asynccontextmanager
async def open_stream(
self,
shield: bool = False,
) -> MsgStream:
# TODO
@ -299,7 +306,11 @@ class Context:
self.cid
)
async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan:
async with MsgStream(
ctx=self,
rx_chan=recv_chan,
shield=shield,
) as rchan:
if self._portal:
self._portal._streams.add(rchan)
@ -308,9 +319,11 @@ class Context:
yield rchan
finally:
# signal ``StopAsyncIteration`` on far end.
await self.send_stop()
if self._portal:
self._portal._streams.add(rchan)
self._portal._streams.remove(rchan)
async def started(self, value: Any) -> None: