Mark stream with EOC when stop message is received

end_of_channel_fixes
Tyler Goodlet 2021-12-15 16:22:04 -05:00
parent 79d63585b0
commit f2ba961e81
1 changed files with 20 additions and 11 deletions

View File

@ -78,6 +78,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# flag to denote end of stream # flag to denote end of stream
self._eoc: bool = False self._eoc: bool = False
self._closed: bool = False
# delegate directly to underlying mem channel # delegate directly to underlying mem channel
def receive_nowait(self): def receive_nowait(self):
@ -98,7 +99,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
msg = await self._rx_chan.receive() msg = await self._rx_chan.receive()
return msg['yield'] return msg['yield']
except KeyError: except KeyError as err:
# internal error should never get here # internal error should never get here
assert msg.get('cid'), ("Received internal error at portal?") assert msg.get('cid'), ("Received internal error at portal?")
@ -107,9 +108,15 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# - 'error' # - 'error'
# possibly just handle msg['stop'] here! # possibly just handle msg['stop'] here!
if msg.get('stop'): if msg.get('stop') or self._eoc:
log.debug(f"{self} was stopped at remote end") log.debug(f"{self} was stopped at remote end")
# XXX: important to set so that a new ``.receive()``
# call (likely by another task using a broadcast receiver)
# doesn't accidentally pull the ``return`` message
# value out of the underlying feed mem chan!
self._eoc = True
# # when the send is closed we assume the stream has # # when the send is closed we assume the stream has
# # terminated and signal this local iterator to stop # # terminated and signal this local iterator to stop
# await self.aclose() # await self.aclose()
@ -117,7 +124,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# XXX: this causes ``ReceiveChannel.__anext__()`` to # XXX: this causes ``ReceiveChannel.__anext__()`` to
# raise a ``StopAsyncIteration`` **and** in our catch # raise a ``StopAsyncIteration`` **and** in our catch
# block below it will trigger ``.aclose()``. # block below it will trigger ``.aclose()``.
raise trio.EndOfChannel raise trio.EndOfChannel from err
# TODO: test that shows stream raising an expected error!!! # TODO: test that shows stream raising an expected error!!!
elif msg.get('error'): elif msg.get('error'):
@ -162,10 +169,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
raise # propagate raise # propagate
async def aclose(self): async def aclose(self):
"""Cancel associated remote actor task and local memory channel '''
on close. Cancel associated remote actor task and local memory channel on
close.
""" '''
# XXX: keep proper adherance to trio's `.aclose()` semantics: # XXX: keep proper adherance to trio's `.aclose()` semantics:
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
rx_chan = self._rx_chan rx_chan = self._rx_chan
@ -178,7 +186,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
return return
self._eoc = True self._closed = True
# NOTE: this is super subtle IPC messaging stuff: # NOTE: this is super subtle IPC messaging stuff:
# Relay stop iteration to far end **iff** we're # Relay stop iteration to far end **iff** we're
@ -310,15 +318,16 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
self, self,
data: Any data: Any
) -> None: ) -> None:
'''Send a message over this stream to the far end. '''
Send a message over this stream to the far end.
''' '''
# if self._eoc:
# raise trio.ClosedResourceError('This stream is already ded')
if self._ctx._error: if self._ctx._error:
raise self._ctx._error # from None raise self._ctx._error # from None
if self._closed:
raise trio.ClosedResourceError('This stream was already closed')
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})