diff --git a/tractor/_spawn.py b/tractor/_spawn.py
index 06f6532..b3657b7 100644
--- a/tractor/_spawn.py
+++ b/tractor/_spawn.py
@@ -18,17 +18,30 @@
 Machinery for actor process spawning using multiple backends.
 
 """
-from __future__ import annotations
 import sys
+import multiprocessing as mp
 import platform
 from typing import (
-    Any, Optional, Callable, TypeVar, TYPE_CHECKING
+    Any, Dict, Optional, Callable,
+    TypeVar,
 )
 from collections.abc import Awaitable
 
 import trio
 from trio_typing import TaskStatus
 
+try:
+    from multiprocessing import semaphore_tracker  # type: ignore
+    resource_tracker = semaphore_tracker
+    resource_tracker._resource_tracker = resource_tracker._semaphore_tracker
+except ImportError:
+    # 3.8 introduces a more general version that also tracks shared mems
+    from multiprocessing import resource_tracker  # type: ignore
+
+from multiprocessing import forkserver  # type: ignore
+from typing import Tuple
+
+from . import _forkserver_override
 from ._debug import (
     maybe_wait_for_debugger,
     acquire_debug_lock,
@@ -47,11 +60,8 @@
 from ._entry import _mp_main
 from ._exceptions import ActorFailure
 
-if TYPE_CHECKING:
-    import multiprocessing as mp
-    ProcessType = TypeVar('ProcessType', mp.Process, trio.Process)
-
 log = get_logger('tractor')
+ProcessType = TypeVar('ProcessType', mp.Process, trio.Process)
 
 # placeholder for an mp start context if so using that backend
 _ctx: Optional[mp.context.BaseContext] = None
@@ -60,7 +70,6 @@
 _spawn_method: str = "trio"
 
 if platform.system() == 'Windows':
-    import multiprocessing as mp
     _ctx = mp.get_context("spawn")
 
     async def proc_waiter(proc: mp.Process) -> None:
@@ -83,7 +92,6 @@ def try_set_start_method(name: str) -> Optional[mp.context.BaseContext]:
     ``subprocess.Popen``.
 
     '''
-    import multiprocessing as mp
     global _ctx
     global _spawn_method
 
@@ -100,7 +108,6 @@ def try_set_start_method(name: str) -> Optional[mp.context.BaseContext]:
             f"Spawn method `{name}` is invalid please choose one of {methods}"
         )
     elif name == 'forkserver':
-        from . import _forkserver_override
         _forkserver_override.override_stdlib()
         _ctx = mp.get_context(name)
     elif name == 'trio':
@@ -148,7 +155,7 @@ async def cancel_on_completion(
 
     portal: Portal,
     actor: Actor,
-    errors: dict[tuple[str, str], Exception],
+    errors: Dict[Tuple[str, str], Exception],
 
 ) -> None:
     '''
@@ -251,12 +258,12 @@ async def new_proc(
     name: str,
     actor_nursery: 'ActorNursery',  # type: ignore  # noqa
     subactor: Actor,
-    errors: dict[tuple[str, str], Exception],
+    errors: Dict[Tuple[str, str], Exception],
 
     # passed through to actor main
-    bind_addr: tuple[str, int],
-    parent_addr: tuple[str, int],
-    _runtime_vars: dict[str, Any],  # serialized and sent to _child
+    bind_addr: Tuple[str, int],
+    parent_addr: Tuple[str, int],
+    _runtime_vars: Dict[str, Any],  # serialized and sent to _child
 
     *,
 
@@ -288,7 +295,7 @@
             # the OS; it otherwise can be passed via the parent channel if
             # we prefer in the future (for privacy).
"--uid", - str(subactor.uid), + str(uid), # Address the child must connect to on startup "--parent_addr", str(parent_addr) @@ -314,8 +321,7 @@ async def new_proc( # wait for actor to spawn and connect back to us # channel should have handshake completed by the # local actor by the time we get a ref to it - event, chan = await actor_nursery._actor.wait_for_peer( - subactor.uid) + event, chan = await actor_nursery._actor.wait_for_peer(uid) except trio.Cancelled: cancelled_during_spawn = True @@ -356,10 +362,54 @@ async def new_proc( task_status.started(portal) # wait for ActorNursery.wait() to be called + n_exited = actor_nursery._join_procs with trio.CancelScope(shield=True): - await actor_nursery._join_procs.wait() + await n_exited.wait() async with trio.open_nursery() as nursery: + + async def soft_wait_and_maybe_cancel_ria_task(): + # This is a "soft" (cancellable) join/reap which + # will remote cancel the actor on a ``trio.Cancelled`` + # condition. + await soft_wait( + proc, + trio.Process.wait, + portal + ) + + if n_exited.is_set(): + # cancel result waiter that may have been spawned in + # tandem if not done already + log.warning( + "Cancelling existing result waiter task for " + f"{subactor.uid}" + ) + nursery.cancel_scope.cancel() + + else: + log.warning( + f'Process for actor {uid} terminated before' + 'nursery exit. ' 'This may mean an IPC' + 'connection failed!' + ) + + nursery.start_soon(soft_wait_and_maybe_cancel_ria_task) + + # TODO: when we finally remove the `.run_in_actor()` api + # we should be able to entirely drop these 2 blocking calls: + # - we don't need to wait on nursery exit to capture + # process-spawn-machinery level errors (and propagate them). + # - we don't need to wait on final results from ria portals + # since this will be done in some higher level wrapper API. + + # XXX: interestingly we can't put this here bc it causes + # the pub-sub tests to fail? wth.. should probably drop + # those XD + # wait for ActorNursery.wait() to be called + # with trio.CancelScope(shield=True): + # await n_exited.wait() + if portal in actor_nursery._cancel_after_result_on_exit: nursery.start_soon( cancel_on_completion, @@ -368,22 +418,6 @@ async def new_proc( errors ) - # This is a "soft" (cancellable) join/reap which - # will remote cancel the actor on a ``trio.Cancelled`` - # condition. - await soft_wait( - proc, - trio.Process.wait, - portal - ) - - # cancel result waiter that may have been spawned in - # tandem if not done already - log.warning( - "Cancelling existing result waiter task for " - f"{subactor.uid}") - nursery.cancel_scope.cancel() - finally: # The "hard" reap since no actor zombies are allowed! 
            # XXX: do this **after** cancellation/teardown to avoid
@@ -400,9 +434,10 @@
                                await proc.wait()
 
                    if is_root_process():
+
                        await maybe_wait_for_debugger(
                            child_in_debug=_runtime_vars.get(
-                                '_debug_mode', False),
+                                '_debug_mode', False)
                        )
 
                    if proc.poll() is None:
@@ -441,30 +476,20 @@ async def mp_new_proc(
     name: str,
     actor_nursery: 'ActorNursery',  # type: ignore  # noqa
     subactor: Actor,
-    errors: dict[tuple[str, str], Exception],
+    errors: Dict[Tuple[str, str], Exception],
     # passed through to actor main
-    bind_addr: tuple[str, int],
-    parent_addr: tuple[str, int],
-    _runtime_vars: dict[str, Any],  # serialized and sent to _child
+    bind_addr: Tuple[str, int],
+    parent_addr: Tuple[str, int],
+    _runtime_vars: Dict[str, Any],  # serialized and sent to _child
     *,
     infect_asyncio: bool = False,
     task_status: TaskStatus[Portal] = trio.TASK_STATUS_IGNORED
 
 ) -> None:
 
-    # uggh zone
-    try:
-        from multiprocessing import semaphore_tracker  # type: ignore
-        resource_tracker = semaphore_tracker
-        resource_tracker._resource_tracker = resource_tracker._semaphore_tracker  # noqa
-    except ImportError:
-        # 3.8 introduces a more general version that also tracks shared mems
-        from multiprocessing import resource_tracker  # type: ignore
-
     assert _ctx
     start_method = _ctx.get_start_method()
     if start_method == 'forkserver':
-        from multiprocessing import forkserver  # type: ignore
         # XXX do our hackery on the stdlib to avoid multiple
        # forkservers (one at each subproc layer).
        fs = forkserver._forkserver