From c9eb466d760508ec78c8cd3501ff33e2e17e36c2 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Mon, 14 Mar 2022 07:08:22 -0400 Subject: [PATCH] Use `MsgStream.subscribe()` in `Context.result()` The case exists where there is multiple tasks consuming from an open 2-way stream created via `Context.open_stream()` where a sibling task is pulling from the stream while some other task also calls `.result()`. Previously the `.result()` call would consume (drain) stream messages directly from the underlying mem chan which would mean any sibling task would not receive those same messages. Instead, make `.result()` check if a stream is open and instead consume (and discard) stream msgs using a `BroadcastReceiver` (via `MsgStream.subscribe()`) such that all interested tasks get copies of the same packets. --- tractor/_streaming.py | 59 +++++++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/tractor/_streaming.py b/tractor/_streaming.py index ec15272..bb86dcd 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -27,7 +27,8 @@ from typing import ( Optional, Callable, AsyncGenerator, - AsyncIterator + AsyncIterator, + TYPE_CHECKING, ) import warnings @@ -41,6 +42,10 @@ from .log import get_logger from .trionics import broadcast_receiver, BroadcastReceiver +if TYPE_CHECKING: + from ._portal import Portal + + log = get_logger(__name__) @@ -378,7 +383,8 @@ class Context: _remote_func_type: Optional[str] = None # only set on the caller side - _portal: Optional['Portal'] = None # type: ignore # noqa + _portal: Optional[Portal] = None # type: ignore # noqa + _stream: Optional[MsgStream] = None _result: Optional[Any] = False _error: Optional[BaseException] = None @@ -486,6 +492,7 @@ class Context: log.cancel(f'Cancelling {side} side of context to {self.chan.uid}') self._cancel_called = True + ipc_broken: bool = False if side == 'caller': if not self._portal: @@ -503,7 +510,14 @@ class Context: # NOTE: we're telling the far end actor to cancel a task # corresponding to *this actor*. The far end local channel # instance is passed to `Actor._cancel_task()` implicitly. - await self._portal.run_from_ns('self', '_cancel_task', cid=cid) + try: + await self._portal.run_from_ns( + 'self', + '_cancel_task', + cid=cid, + ) + except trio.BrokenResourceError: + ipc_broken = True if cs.cancelled_caught: # XXX: there's no way to know if the remote task was indeed @@ -519,7 +533,10 @@ class Context: "Timed out on cancelling remote task " f"{cid} for {self._portal.channel.uid}") - # callee side remote task + elif ipc_broken: + log.cancel( + "Transport layer was broken before cancel request " + f"{cid} for {self._portal.channel.uid}") else: self._cancel_msg = msg @@ -607,6 +624,7 @@ class Context: ctx=self, rx_chan=ctx._recv_chan, ) as stream: + self._stream = stream if self._portal: self._portal._streams.add(stream) @@ -648,25 +666,22 @@ class Context: if not self._recv_chan._closed: # type: ignore - # wait for a final context result consuming - # and discarding any bi dir stream msgs still - # in transit from the far end. - while True: + def consume( + msg: dict, - msg = await self._recv_chan.receive() + ) -> Optional[dict]: try: - self._result = msg['return'] - break + return msg['return'] except KeyError as msgerr: if 'yield' in msg: # far end task is still streaming to us so discard log.warning(f'Discarding stream delivered {msg}') - continue + return elif 'stop' in msg: log.debug('Remote stream terminated') - continue + return # internal error should never get here assert msg.get('cid'), ( @@ -676,6 +691,24 @@ class Context: msg, self._portal.channel ) from msgerr + # wait for a final context result consuming + # and discarding any bi dir stream msgs still + # in transit from the far end. + if self._stream: + async with self._stream.subscribe() as bstream: + async for msg in bstream: + result = consume(msg) + if result: + self._result = result + + if not self._result: + while True: + msg = await self._recv_chan.receive() + result = consume(msg) + if result: + self._result = result + break + return self._result async def started(