Support passing `shield` at stream contruction
parent
b5116c5a51
commit
e46ef8ae3f
|
@ -37,10 +37,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
self,
|
||||
ctx: 'Context', # typing: ignore # noqa
|
||||
rx_chan: trio.abc.ReceiveChannel,
|
||||
shield: bool = False,
|
||||
) -> None:
|
||||
self._ctx = ctx
|
||||
self._rx_chan = rx_chan
|
||||
self._shielded = False
|
||||
self._shielded = shield
|
||||
|
||||
# delegate directly to underlying mem channel
|
||||
def receive_nowait(self):
|
||||
|
@ -112,6 +113,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
"""Shield this stream's underlying channel such that a local consumer task
|
||||
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
||||
|
||||
Note that here, "shielding" here guards against relaying
|
||||
a ``'stop'`` message to the far end of the stream thus keeping
|
||||
the stream machinery active and ready for further use, it does
|
||||
not have anything to do with an internal ``trio.CancelScope``.
|
||||
|
||||
"""
|
||||
self._shielded = True
|
||||
yield self
|
||||
|
@ -162,7 +168,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
await self._ctx.send_stop()
|
||||
|
||||
# close the local mem chan
|
||||
rx_chan.close()
|
||||
await rx_chan.aclose()
|
||||
|
||||
# TODO: but make it broadcasting to consumers
|
||||
# def clone(self):
|
||||
|
@ -281,6 +287,7 @@ class Context:
|
|||
@asynccontextmanager
|
||||
async def open_stream(
|
||||
self,
|
||||
shield: bool = False,
|
||||
) -> MsgStream:
|
||||
# TODO
|
||||
|
||||
|
@ -299,7 +306,11 @@ class Context:
|
|||
self.cid
|
||||
)
|
||||
|
||||
async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan:
|
||||
async with MsgStream(
|
||||
ctx=self,
|
||||
rx_chan=recv_chan,
|
||||
shield=shield,
|
||||
) as rchan:
|
||||
|
||||
if self._portal:
|
||||
self._portal._streams.add(rchan)
|
||||
|
@ -308,9 +319,11 @@ class Context:
|
|||
yield rchan
|
||||
|
||||
finally:
|
||||
# signal ``StopAsyncIteration`` on far end.
|
||||
await self.send_stop()
|
||||
|
||||
if self._portal:
|
||||
self._portal._streams.add(rchan)
|
||||
self._portal._streams.remove(rchan)
|
||||
|
||||
async def started(self, value: Any) -> None:
|
||||
|
||||
|
|
Loading…
Reference in New Issue