Mark stream with EOC when stop message is received

end_of_channel_fixes
Tyler Goodlet 2021-12-15 16:22:04 -05:00
parent 79d63585b0
commit f2ba961e81
1 changed files with 20 additions and 11 deletions

View File

@ -78,6 +78,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# flag to denote end of stream
self._eoc: bool = False
self._closed: bool = False
# delegate directly to underlying mem channel
def receive_nowait(self):
@ -98,7 +99,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
msg = await self._rx_chan.receive()
return msg['yield']
except KeyError:
except KeyError as err:
# internal error should never get here
assert msg.get('cid'), ("Received internal error at portal?")
@ -107,9 +108,15 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# - 'error'
# possibly just handle msg['stop'] here!
if msg.get('stop'):
if msg.get('stop') or self._eoc:
log.debug(f"{self} was stopped at remote end")
# XXX: important to set so that a new ``.receive()``
# call (likely by another task using a broadcast receiver)
# doesn't accidentally pull the ``return`` message
# value out of the underlying feed mem chan!
self._eoc = True
# # when the send is closed we assume the stream has
# # terminated and signal this local iterator to stop
# await self.aclose()
@ -117,7 +124,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# XXX: this causes ``ReceiveChannel.__anext__()`` to
# raise a ``StopAsyncIteration`` **and** in our catch
# block below it will trigger ``.aclose()``.
raise trio.EndOfChannel
raise trio.EndOfChannel from err
# TODO: test that shows stream raising an expected error!!!
elif msg.get('error'):
@ -162,10 +169,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
raise # propagate
async def aclose(self):
"""Cancel associated remote actor task and local memory channel
on close.
'''
Cancel associated remote actor task and local memory channel on
close.
"""
'''
# XXX: keep proper adherance to trio's `.aclose()` semantics:
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
rx_chan = self._rx_chan
@ -178,7 +186,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
return
self._eoc = True
self._closed = True
# NOTE: this is super subtle IPC messaging stuff:
# Relay stop iteration to far end **iff** we're
@ -310,15 +318,16 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
self,
data: Any
) -> None:
'''Send a message over this stream to the far end.
'''
Send a message over this stream to the far end.
'''
# if self._eoc:
# raise trio.ClosedResourceError('This stream is already ded')
if self._ctx._error:
raise self._ctx._error # from None
if self._closed:
raise trio.ClosedResourceError('This stream was already closed')
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})