Support passing `shield` at stream contruction
							parent
							
								
									b5116c5a51
								
							
						
					
					
						commit
						e46ef8ae3f
					
				|  | @ -37,10 +37,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): | |||
|         self, | ||||
|         ctx: 'Context',  # typing: ignore # noqa | ||||
|         rx_chan: trio.abc.ReceiveChannel, | ||||
|         shield: bool = False, | ||||
|     ) -> None: | ||||
|         self._ctx = ctx | ||||
|         self._rx_chan = rx_chan | ||||
|         self._shielded = False | ||||
|         self._shielded = shield | ||||
| 
 | ||||
|     # delegate directly to underlying mem channel | ||||
|     def receive_nowait(self): | ||||
|  | @ -112,6 +113,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): | |||
|         """Shield this stream's underlying channel such that a local consumer task | ||||
|         can be cancelled (and possibly restarted) using ``trio.Cancelled``. | ||||
| 
 | ||||
|         Note that here, "shielding" here guards against relaying | ||||
|         a ``'stop'`` message to the far end of the stream thus keeping | ||||
|         the stream machinery active and ready for further use, it does | ||||
|         not have anything to do with an internal ``trio.CancelScope``. | ||||
| 
 | ||||
|         """ | ||||
|         self._shielded = True | ||||
|         yield self | ||||
|  | @ -162,7 +168,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): | |||
|             await self._ctx.send_stop() | ||||
| 
 | ||||
|         # close the local mem chan | ||||
|         rx_chan.close() | ||||
|         await rx_chan.aclose() | ||||
| 
 | ||||
|     # TODO: but make it broadcasting to consumers | ||||
|     # def clone(self): | ||||
|  | @ -281,6 +287,7 @@ class Context: | |||
|     @asynccontextmanager | ||||
|     async def open_stream( | ||||
|         self, | ||||
|         shield: bool = False, | ||||
|     ) -> MsgStream: | ||||
|         # TODO | ||||
| 
 | ||||
|  | @ -299,7 +306,11 @@ class Context: | |||
|             self.cid | ||||
|         ) | ||||
| 
 | ||||
|         async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan: | ||||
|         async with MsgStream( | ||||
|             ctx=self, | ||||
|             rx_chan=recv_chan, | ||||
|             shield=shield, | ||||
|         ) as rchan: | ||||
| 
 | ||||
|             if self._portal: | ||||
|                 self._portal._streams.add(rchan) | ||||
|  | @ -308,9 +319,11 @@ class Context: | |||
|                 yield rchan | ||||
| 
 | ||||
|             finally: | ||||
|                 # signal ``StopAsyncIteration`` on far end. | ||||
|                 await self.send_stop() | ||||
| 
 | ||||
|                 if self._portal: | ||||
|                     self._portal._streams.add(rchan) | ||||
|                     self._portal._streams.remove(rchan) | ||||
| 
 | ||||
|     async def started(self, value: Any) -> None: | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue