diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 5691fbd..e25a000 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -13,6 +13,12 @@ async def stream_seq(sequence): yield i await trio.sleep(0.1) + # block indefinitely waiting to be cancelled by ``aclose()`` call + with trio.open_cancel_scope() as cs: + await trio.sleep(float('inf')) + assert 0 + assert cs.cancelled_caught + async def stream_from_single_subactor(): """Verify we can spawn a daemon actor and retrieve streamed data. @@ -37,17 +43,22 @@ async def stream_from_single_subactor(): ) # it'd sure be nice to have an asyncitertools here... iseq = iter(seq) + ival = next(iseq) async for val in agen: - assert val == next(iseq) - # TODO: test breaking the loop (should it kill the - # far end?) - # break - # terminate far-end async-gen - # await gen.asend(None) - # break + assert val == ival + try: + ival = next(iseq) + except StopIteration: + # should cancel far end task which will be + # caught and no error is raised + await agen.aclose() - # stop all spawned subactors - await portal.cancel_actor() + await trio.sleep(0.3) + try: + await agen.__anext__() + except StopAsyncIteration: + # stop all spawned subactors + await portal.cancel_actor() # await nursery.cancel() diff --git a/tractor/_actor.py b/tractor/_actor.py index f1beb45..750487a 100644 --- a/tractor/_actor.py +++ b/tractor/_actor.py @@ -103,7 +103,7 @@ async def _invoke( # TODO: we should really support a proper # `StopAsyncIteration` system here for returning a final # value if desired - await chan.send({'stop': None, 'cid': cid}) + await chan.send({'stop': True, 'cid': cid}) else: if treat_as_gen: await chan.send({'functype': 'asyncgen', 'cid': cid}) @@ -117,7 +117,7 @@ async def _invoke( if not cs.cancelled_caught: # task was not cancelled so we can instruct the # far end async gen to tear down - await chan.send({'stop': None, 'cid': cid}) + await chan.send({'stop': True, 'cid': cid}) else: await chan.send({'functype': 'asyncfunction', 'cid': cid}) with trio.open_cancel_scope() as cs: @@ -141,7 +141,7 @@ async def _invoke( tasks = actor._rpc_tasks.get(chan, None) if tasks: try: - tasks.remove((cs, func)) + scope, func = tasks.pop(cid) except ValueError: # If we're cancelled before the task returns then the # cancel scope will not have been inserted yet @@ -197,7 +197,7 @@ class Actor: self._no_more_rpc_tasks.set() self._rpc_tasks: Dict[ Channel, - List[Tuple[trio._core._run.CancelScope, typing.Callable]] + Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]] ] = {} # map {uids -> {callids -> waiter queues}} self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {} @@ -294,8 +294,11 @@ class Actor: # # XXX: is this necessary (GC should do it?) if chan.connected(): log.debug(f"Disconnecting channel {chan}") - await chan.send(None) - await chan.aclose() + try: + await chan.send(None) + await chan.aclose() + except trio.BrokenResourceError: + log.exception(f"Channel for {chan.uid} was already zonked..") async def _push_result(self, actorid, cid: str, msg: dict) -> None: """Push an RPC result to the local consumer's queue. @@ -344,9 +347,10 @@ class Actor: if msg is None: # terminate sentinel log.debug( f"Cancelling all tasks for {chan} from {chan.uid}") - for scope, func in self._rpc_tasks.pop(chan, ()): + for cid, (scope, func) in self._rpc_tasks.pop( + chan, {} + ).items(): scope.cancel() - log.debug( f"Msg loop signalled to terminate for" f" {chan} from {chan.uid}") @@ -354,10 +358,20 @@ class Actor: log.debug(f"Received msg {msg} from {chan.uid}") cid = msg.get('cid') if cid: - # deliver response to local caller/waiter - await self._push_result(chan.uid, cid, msg) - log.debug( - f"Waiting on next msg for {chan} from {chan.uid}") + cancel = msg.get('cancel') + if cancel: + # right now this is only implicitly used by + # async generator IPC + scope, func = self._rpc_tasks[chan][cid] + log.debug( + f"Received cancel request for task {cid}" + f" from {chan.uid}") + scope.cancel() + else: + # deliver response to local caller/waiter + await self._push_result(chan.uid, cid, msg) + log.debug( + f"Waiting on next msg for {chan} from {chan.uid}") continue # process command request @@ -403,7 +417,7 @@ class Actor: log.info(f"RPC func is {func}") # store cancel scope such that the rpc task can be # cancelled gracefully if requested - self._rpc_tasks.setdefault(chan, []).append((cs, func)) + self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func) log.debug( f"Waiting on next msg for {chan} from {chan.uid}") else: @@ -611,9 +625,9 @@ class Actor: """ tasks = self._rpc_tasks log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}") - for chan, scopes in tasks.items(): + for chan, cids2scopes in tasks.items(): log.debug(f"Cancelling all tasks for {chan.uid}") - for scope, func in scopes: + for cid, (scope, func) in cids2scopes.items(): log.debug(f"Cancelling task for {func}") scope.cancel() if tasks: diff --git a/tractor/_ipc.py b/tractor/_ipc.py index 52b1ca5..f7bebbe 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -20,7 +20,7 @@ class StreamQueue: self._agen = self._iter_packets() self._laddr = self.stream.socket.getsockname()[:2] self._raddr = self.stream.socket.getpeername()[:2] - self._send_lock = trio.Lock() + self._send_lock = trio.StrictFIFOLock() async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: """Yield packets from the underlying stream. diff --git a/tractor/_portal.py b/tractor/_portal.py index 28e28f4..fff39db 100644 --- a/tractor/_portal.py +++ b/tractor/_portal.py @@ -4,7 +4,7 @@ Portal api import importlib import inspect import typing -from typing import Tuple, Any, Dict, Optional +from typing import Tuple, Any, Dict, Optional, Set import trio from async_generator import asynccontextmanager @@ -65,6 +65,7 @@ class Portal: self._expect_result: Optional[ Tuple[str, Any, str, Dict[str, Any]] ] = None + self._agens: Set[typing.AsyncGenerator] = set() async def aclose(self) -> None: log.debug(f"Closing {self}") @@ -139,15 +140,19 @@ class Portal: "Received internal error at portal?") raise unpack_error(msg, self.channel) - except StopAsyncIteration: - log.debug( + except GeneratorExit: + # for now this msg cancels an ongoing remote task + await self.channel.send({'cancel': True, 'cid': cid}) + log.warn( f"Cancelling async gen call {cid} to " f"{self.channel.uid}") raise # TODO: use AsyncExitStack to aclose() all agens # on teardown - return yield_from_q() + agen = yield_from_q() + self._agens.add(agen) + return agen elif resptype == 'return': msg = await q.get() @@ -267,13 +272,18 @@ async def open_portal( nursery.start_soon(actor._process_messages, channel) portal = Portal(channel) - yield portal + try: + yield portal + finally: + # tear down all async generators + for agen in portal._agens: + await agen.aclose() - # cancel remote channel-msg loop - if channel.connected(): - await portal.close() + # cancel remote channel-msg loop + if channel.connected(): + await portal.close() - # cancel background msg loop task - nursery.cancel_scope.cancel() - if was_connected: - await channel.aclose() + # cancel background msg loop task + nursery.cancel_scope.cancel() + if was_connected: + await channel.aclose()