diff --git a/tests/test_debugger.py b/tests/test_debugger.py index 53c3c84..8a0423b 100644 --- a/tests/test_debugger.py +++ b/tests/test_debugger.py @@ -6,6 +6,7 @@ All these tests can be understood (somewhat) by running the equivalent TODO: None of these tests have been run successfully on windows yet. """ +import time from os import path import pytest @@ -41,10 +42,16 @@ def mk_cmd(ex_name: str) -> str: @pytest.fixture def spawn( + start_method, testdir, arb_addr, ) -> 'pexpect.spawn': + if start_method != 'trio': + pytest.skip( + "Debugger tests are only supported on the trio backend" + ) + def _spawn(cmd): return testdir.spawn( cmd=mk_cmd(cmd), @@ -370,6 +377,8 @@ def test_root_nursery_cancels_before_child_releases_tty_lock(spawn, start_method child has unblocked (which can happen when it has the tty lock and is engaged in pdb) it is indeed cancelled after exiting the debugger. """ + timed_out_early = False + child = spawn('root_cancelled_but_child_is_in_tty_lock') child.expect(r"\(Pdb\+\+\)") @@ -377,9 +386,13 @@ def test_root_nursery_cancels_before_child_releases_tty_lock(spawn, start_method before = str(child.before.decode()) assert "NameError: name 'doggypants' is not defined" in before assert "tractor._exceptions.RemoteActorError: ('name_error'" not in before + time.sleep(0.5) + child.sendline('c') - for _ in range(4): + + for i in range(4): + time.sleep(0.5) try: child.expect(r"\(Pdb\+\+\)") except TimeoutError: @@ -390,13 +403,26 @@ def test_root_nursery_cancels_before_child_releases_tty_lock(spawn, start_method else: raise + except pexpect.exceptions.EOF: + print(f"Failed early on {i}?") + before = str(child.before.decode()) + + timed_out_early = True + + # race conditions on how fast the continue is sent? + break + + before = str(child.before.decode()) assert "NameError: name 'doggypants' is not defined" in before child.sendline('c') child.expect(pexpect.EOF) - before = str(child.before.decode()) - assert "tractor._exceptions.RemoteActorError: ('spawner0'" in before - assert "tractor._exceptions.RemoteActorError: ('name_error'" in before - assert "NameError: name 'doggypants' is not defined" in before + + if not timed_out_early: + + before = str(child.before.decode()) + assert "tractor._exceptions.RemoteActorError: ('spawner0'" in before + assert "tractor._exceptions.RemoteActorError: ('name_error'" in before + assert "NameError: name 'doggypants' is not defined" in before diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 919b278..a64515f 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -7,6 +7,7 @@ import platform import trio import tractor +from tractor.testing import tractor_test import pytest @@ -53,6 +54,7 @@ async def stream_from_single_subactor(stream_func_name): """Verify we can spawn a daemon actor and retrieve streamed data. """ async with tractor.find_actor('streamerd') as portals: + if not portals: # only one per host address, spawns an actor if None async with tractor.open_nursery() as nursery: @@ -73,8 +75,10 @@ async def stream_from_single_subactor(stream_func_name): # it'd sure be nice to have an asyncitertools here... iseq = iter(seq) ival = next(iseq) + async for val in stream: assert val == ival + try: ival = next(iseq) except StopIteration: @@ -83,6 +87,7 @@ async def stream_from_single_subactor(stream_func_name): await stream.aclose() await trio.sleep(0.3) + try: await stream.__anext__() except StopAsyncIteration: @@ -109,8 +114,11 @@ def test_stream_from_single_subactor(arb_addr, start_method, stream_func): # this is the first 2 actors, streamer_1 and streamer_2 async def stream_data(seed): + for i in range(seed): + yield i + # trigger scheduler to simulate practical usage await trio.sleep(0) @@ -246,3 +254,68 @@ def test_not_fast_enough_quad( else: # should be cancelled mid-streaming assert results is None + + +@tractor_test +async def test_respawn_consumer_task( + arb_addr, + spawn_backend, + loglevel, +): + """Verify that ``._portal.ReceiveStream.shield()`` + sucessfully protects the underlying IPC channel from being closed + when cancelling and respawning a consumer task. + + This also serves to verify that all values from the stream can be + received despite the respawns. + + """ + stream = None + + async with tractor.open_nursery() as n: + + stream = await(await n.run_in_actor( + 'streamer', + stream_data, + seed=11, + )).result() + + expect = set(range(11)) + received = [] + + # this is the re-spawn task routine + async def consume(task_status=trio.TASK_STATUS_IGNORED): + print('starting consume task..') + nonlocal stream + + with trio.CancelScope() as cs: + task_status.started(cs) + + # shield stream's underlying channel from cancellation + with stream.shield(): + + async for v in stream: + print(f'from stream: {v}') + expect.remove(v) + received.append(v) + + print('exited consume') + + async with trio.open_nursery() as ln: + cs = await ln.start(consume) + + while True: + + await trio.sleep(0.1) + + if received[-1] % 2 == 0: + + print('cancelling consume task..') + cs.cancel() + + # respawn + cs = await ln.start(consume) + + if not expect: + print("all values streamed, BREAKING") + break diff --git a/tractor/_ipc.py b/tractor/_ipc.py index a3a271d..7f6a498 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -214,9 +214,12 @@ class Channel: # # time is pointless # await self.msgstream.send(sent) except trio.BrokenResourceError: + if not self._autorecon: raise + await self.aclose() + if self._autorecon: # attempt reconnect await self._reconnect() continue diff --git a/tractor/_portal.py b/tractor/_portal.py index af3c1b5..15d86e3 100644 --- a/tractor/_portal.py +++ b/tractor/_portal.py @@ -4,9 +4,10 @@ Portal api import importlib import inspect import typing -from typing import Tuple, Any, Dict, Optional, Set +from typing import Tuple, Any, Dict, Optional, Set, Iterator from functools import partial from dataclasses import dataclass +from contextlib import contextmanager import trio from async_generator import asynccontextmanager @@ -37,7 +38,7 @@ async def maybe_open_nursery( yield nursery -class StreamReceiveChannel(trio.abc.ReceiveChannel): +class ReceiveStream(trio.abc.ReceiveChannel): """A wrapper around a ``trio._channel.MemoryReceiveChannel`` with special behaviour for signalling stream termination across an inter-actor ``Channel``. This is the type returned to a local task @@ -59,6 +60,7 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel): self._cid = cid self._rx_chan = rx_chan self._portal = portal + self._shielded = False # delegate directly to underlying mem channel def receive_nowait(self): @@ -83,6 +85,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel): "Received internal error at portal?") raise unpack_error(msg, self._portal.channel) + @contextmanager + def shield( + self + ) -> Iterator['ReceiveStream']: # noqa + """Shield this stream's underlying channel such that a local consumer task + can be cancelled (and possibly restarted) using ``trio.Cancelled``. + + """ + self._shielded = True + yield self + self._shielded = False + async def aclose(self): """Cancel associated remote actor task and local memory channel on close. @@ -90,12 +104,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel): if self._rx_chan._closed: log.warning(f"{self} is already closed") return + + if self._shielded: + log.warning(f"{self} is shielded, portal channel being kept alive") + return + cid = self._cid with trio.move_on_after(0.5) as cs: cs.shield = True log.warning( f"Cancelling stream {cid} to " f"{self._portal.channel.uid}") + # NOTE: we're telling the far end actor to cancel a task # corresponding to *this actor*. The far end local channel # instance is passed to `Actor._cancel_task()` implicitly. @@ -136,7 +156,7 @@ class Portal: self._expect_result: Optional[ Tuple[str, Any, str, Dict[str, Any]] ] = None - self._streams: Set[StreamReceiveChannel] = set() + self._streams: Set[ReceiveStream] = set() self.actor = current_actor() async def _submit( @@ -199,7 +219,7 @@ class Portal: # to make async-generators the fundamental IPC API over channels! # (think `yield from`, `gen.send()`, and functional reactive stuff) if resptype == 'yield': # stream response - rchan = StreamReceiveChannel(cid, recv_chan, self) + rchan = ReceiveStream(cid, recv_chan, self) self._streams.add(rchan) return rchan @@ -302,7 +322,7 @@ class LocalPortal: A compatibility shim for normal portals but for invoking functions using an in process actor instance. """ - actor: 'Actor' # type: ignore + actor: 'Actor' # type: ignore # noqa channel: Channel async def run(self, ns: str, func_name: str, **kwargs) -> Any: