diff --git a/tractor/_streaming.py b/tractor/_streaming.py index 171ca4b..fbf253b 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -78,6 +78,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): # flag to denote end of stream self._eoc: bool = False + self._closed: bool = False # delegate directly to underlying mem channel def receive_nowait(self): @@ -98,7 +99,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): msg = await self._rx_chan.receive() return msg['yield'] - except KeyError: + except KeyError as err: # internal error should never get here assert msg.get('cid'), ("Received internal error at portal?") @@ -107,9 +108,15 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): # - 'error' # possibly just handle msg['stop'] here! - if msg.get('stop'): + if msg.get('stop') or self._eoc: log.debug(f"{self} was stopped at remote end") + # XXX: important to set so that a new ``.receive()`` + # call (likely by another task using a broadcast receiver) + # doesn't accidentally pull the ``return`` message + # value out of the underlying feed mem chan! + self._eoc = True + # # when the send is closed we assume the stream has # # terminated and signal this local iterator to stop # await self.aclose() @@ -117,7 +124,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): # XXX: this causes ``ReceiveChannel.__anext__()`` to # raise a ``StopAsyncIteration`` **and** in our catch # block below it will trigger ``.aclose()``. - raise trio.EndOfChannel + raise trio.EndOfChannel from err # TODO: test that shows stream raising an expected error!!! elif msg.get('error'): @@ -162,10 +169,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): raise # propagate async def aclose(self): - """Cancel associated remote actor task and local memory channel - on close. + ''' + Cancel associated remote actor task and local memory channel on + close. - """ + ''' # XXX: keep proper adherance to trio's `.aclose()` semantics: # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose rx_chan = self._rx_chan @@ -178,7 +186,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose return - self._eoc = True + self._closed = True # NOTE: this is super subtle IPC messaging stuff: # Relay stop iteration to far end **iff** we're @@ -310,15 +318,16 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel): self, data: Any ) -> None: - '''Send a message over this stream to the far end. + ''' + Send a message over this stream to the far end. ''' - # if self._eoc: - # raise trio.ClosedResourceError('This stream is already ded') - if self._ctx._error: raise self._ctx._error # from None + if self._closed: + raise trio.ClosedResourceError('This stream was already closed') + await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})