diff --git a/tractor/_streaming.py b/tractor/_streaming.py index ec15272..bb86dcd 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -27,7 +27,8 @@ from typing import ( Optional, Callable, AsyncGenerator, - AsyncIterator + AsyncIterator, + TYPE_CHECKING, ) import warnings @@ -41,6 +42,10 @@ from .log import get_logger from .trionics import broadcast_receiver, BroadcastReceiver +if TYPE_CHECKING: + from ._portal import Portal + + log = get_logger(__name__) @@ -378,7 +383,8 @@ class Context: _remote_func_type: Optional[str] = None # only set on the caller side - _portal: Optional['Portal'] = None # type: ignore # noqa + _portal: Optional[Portal] = None # type: ignore # noqa + _stream: Optional[MsgStream] = None _result: Optional[Any] = False _error: Optional[BaseException] = None @@ -486,6 +492,7 @@ class Context: log.cancel(f'Cancelling {side} side of context to {self.chan.uid}') self._cancel_called = True + ipc_broken: bool = False if side == 'caller': if not self._portal: @@ -503,7 +510,14 @@ class Context: # NOTE: we're telling the far end actor to cancel a task # corresponding to *this actor*. The far end local channel # instance is passed to `Actor._cancel_task()` implicitly. - await self._portal.run_from_ns('self', '_cancel_task', cid=cid) + try: + await self._portal.run_from_ns( + 'self', + '_cancel_task', + cid=cid, + ) + except trio.BrokenResourceError: + ipc_broken = True if cs.cancelled_caught: # XXX: there's no way to know if the remote task was indeed @@ -519,7 +533,10 @@ class Context: "Timed out on cancelling remote task " f"{cid} for {self._portal.channel.uid}") - # callee side remote task + elif ipc_broken: + log.cancel( + "Transport layer was broken before cancel request " + f"{cid} for {self._portal.channel.uid}") else: self._cancel_msg = msg @@ -607,6 +624,7 @@ class Context: ctx=self, rx_chan=ctx._recv_chan, ) as stream: + self._stream = stream if self._portal: self._portal._streams.add(stream) @@ -648,25 +666,22 @@ class Context: if not self._recv_chan._closed: # type: ignore - # wait for a final context result consuming - # and discarding any bi dir stream msgs still - # in transit from the far end. - while True: + def consume( + msg: dict, - msg = await self._recv_chan.receive() + ) -> Optional[dict]: try: - self._result = msg['return'] - break + return msg['return'] except KeyError as msgerr: if 'yield' in msg: # far end task is still streaming to us so discard log.warning(f'Discarding stream delivered {msg}') - continue + return elif 'stop' in msg: log.debug('Remote stream terminated') - continue + return # internal error should never get here assert msg.get('cid'), ( @@ -676,6 +691,24 @@ class Context: msg, self._portal.channel ) from msgerr + # wait for a final context result consuming + # and discarding any bi dir stream msgs still + # in transit from the far end. + if self._stream: + async with self._stream.subscribe() as bstream: + async for msg in bstream: + result = consume(msg) + if result: + self._result = result + + if not self._result: + while True: + msg = await self._recv_chan.receive() + result = consume(msg) + if result: + self._result = result + break + return self._result async def started(