Support passing `shield` at stream contruction
parent
4240efc7e3
commit
1f8966ba64
tractor
|
@ -37,10 +37,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||||
self,
|
self,
|
||||||
ctx: 'Context', # typing: ignore # noqa
|
ctx: 'Context', # typing: ignore # noqa
|
||||||
rx_chan: trio.abc.ReceiveChannel,
|
rx_chan: trio.abc.ReceiveChannel,
|
||||||
|
shield: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._ctx = ctx
|
self._ctx = ctx
|
||||||
self._rx_chan = rx_chan
|
self._rx_chan = rx_chan
|
||||||
self._shielded = False
|
self._shielded = shield
|
||||||
|
|
||||||
# delegate directly to underlying mem channel
|
# delegate directly to underlying mem channel
|
||||||
def receive_nowait(self):
|
def receive_nowait(self):
|
||||||
|
@ -112,6 +113,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||||
"""Shield this stream's underlying channel such that a local consumer task
|
"""Shield this stream's underlying channel such that a local consumer task
|
||||||
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
||||||
|
|
||||||
|
Note that here, "shielding" here guards against relaying
|
||||||
|
a ``'stop'`` message to the far end of the stream thus keeping
|
||||||
|
the stream machinery active and ready for further use, it does
|
||||||
|
not have anything to do with an internal ``trio.CancelScope``.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self._shielded = True
|
self._shielded = True
|
||||||
yield self
|
yield self
|
||||||
|
@ -162,7 +168,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||||
await self._ctx.send_stop()
|
await self._ctx.send_stop()
|
||||||
|
|
||||||
# close the local mem chan
|
# close the local mem chan
|
||||||
rx_chan.close()
|
await rx_chan.aclose()
|
||||||
|
|
||||||
# TODO: but make it broadcasting to consumers
|
# TODO: but make it broadcasting to consumers
|
||||||
# def clone(self):
|
# def clone(self):
|
||||||
|
@ -281,6 +287,7 @@ class Context:
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def open_stream(
|
async def open_stream(
|
||||||
self,
|
self,
|
||||||
|
shield: bool = False,
|
||||||
) -> MsgStream:
|
) -> MsgStream:
|
||||||
# TODO
|
# TODO
|
||||||
|
|
||||||
|
@ -299,7 +306,11 @@ class Context:
|
||||||
self.cid
|
self.cid
|
||||||
)
|
)
|
||||||
|
|
||||||
async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan:
|
async with MsgStream(
|
||||||
|
ctx=self,
|
||||||
|
rx_chan=recv_chan,
|
||||||
|
shield=shield,
|
||||||
|
) as rchan:
|
||||||
|
|
||||||
if self._portal:
|
if self._portal:
|
||||||
self._portal._streams.add(rchan)
|
self._portal._streams.add(rchan)
|
||||||
|
@ -308,9 +319,11 @@ class Context:
|
||||||
yield rchan
|
yield rchan
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
# signal ``StopAsyncIteration`` on far end.
|
||||||
await self.send_stop()
|
await self.send_stop()
|
||||||
|
|
||||||
if self._portal:
|
if self._portal:
|
||||||
self._portal._streams.add(rchan)
|
self._portal._streams.remove(rchan)
|
||||||
|
|
||||||
async def started(self, value: Any) -> None:
|
async def started(self, value: Any) -> None:
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue