From 75bb1addedd7c53516c791971f41e21198f81934 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Wed, 16 Feb 2022 12:08:35 -0500 Subject: [PATCH] Avoid importing mp for as long as possible --- tractor/_mp_fixup_main.py | 9 +++---- tractor/_spawn.py | 52 ++++++++++++++++++++------------------- tractor/_state.py | 2 +- tractor/_supervise.py | 6 +++-- 4 files changed, 36 insertions(+), 33 deletions(-) diff --git a/tractor/_mp_fixup_main.py b/tractor/_mp_fixup_main.py index 9d0352b..11d5f1c 100644 --- a/tractor/_mp_fixup_main.py +++ b/tractor/_mp_fixup_main.py @@ -18,9 +18,9 @@ Helpers pulled mostly verbatim from ``multiprocessing.spawn`` to aid with "fixing up" the ``__main__`` module in subprocesses. -These helpers are needed for any spawing backend that doesn't already handle this. -For example when using ``trio_run_in_process`` it is needed but obviously not when -we're already using ``multiprocessing``. +These helpers are needed for any spawing backend that doesn't already +handle this. For example when using ``trio_run_in_process`` it is needed +but obviously not when we're already using ``multiprocessing``. """ import os @@ -28,13 +28,12 @@ import sys import platform import types import runpy -from typing import Dict ORIGINAL_DIR = os.path.abspath(os.getcwd()) -def _mp_figure_out_main() -> Dict[str, str]: +def _mp_figure_out_main() -> dict[str, str]: """Taken from ``multiprocessing.spawn.get_preparation_data()``. Retrieve parent actor `__main__` module data. diff --git a/tractor/_spawn.py b/tractor/_spawn.py index 3d7e6b1..c9462df 100644 --- a/tractor/_spawn.py +++ b/tractor/_spawn.py @@ -18,30 +18,17 @@ Machinery for actor process spawning using multiple backends. """ +from __future__ import annotations import sys -import multiprocessing as mp import platform from typing import ( - Any, Dict, Optional, Callable, - TypeVar, + Any, Optional, Callable, TypeVar, TYPE_CHECKING ) from collections.abc import Awaitable import trio from trio_typing import TaskStatus -try: - from multiprocessing import semaphore_tracker # type: ignore - resource_tracker = semaphore_tracker - resource_tracker._resource_tracker = resource_tracker._semaphore_tracker -except ImportError: - # 3.8 introduces a more general version that also tracks shared mems - from multiprocessing import resource_tracker # type: ignore - -from multiprocessing import forkserver # type: ignore -from typing import Tuple - -from . import _forkserver_override from ._debug import ( maybe_wait_for_debugger, acquire_debug_lock, @@ -60,8 +47,11 @@ from ._entry import _mp_main from ._exceptions import ActorFailure +if TYPE_CHECKING: + import multiprocessing as mp + ProcessType = TypeVar('ProcessType', mp.Process, trio.Process) + log = get_logger('tractor') -ProcessType = TypeVar('ProcessType', mp.Process, trio.Process) # placeholder for an mp start context if so using that backend _ctx: Optional[mp.context.BaseContext] = None @@ -92,6 +82,7 @@ def try_set_start_method(name: str) -> Optional[mp.context.BaseContext]: ``subprocess.Popen``. ''' + import multiprocessing as mp global _ctx global _spawn_method @@ -108,6 +99,7 @@ def try_set_start_method(name: str) -> Optional[mp.context.BaseContext]: f"Spawn method `{name}` is invalid please choose one of {methods}" ) elif name == 'forkserver': + from . import _forkserver_override _forkserver_override.override_stdlib() _ctx = mp.get_context(name) elif name == 'trio': @@ -155,7 +147,7 @@ async def cancel_on_completion( portal: Portal, actor: Actor, - errors: Dict[Tuple[str, str], Exception], + errors: dict[tuple[str, str], Exception], ) -> None: ''' @@ -258,12 +250,12 @@ async def new_proc( name: str, actor_nursery: 'ActorNursery', # type: ignore # noqa subactor: Actor, - errors: Dict[Tuple[str, str], Exception], + errors: dict[tuple[str, str], Exception], # passed through to actor main - bind_addr: Tuple[str, int], - parent_addr: Tuple[str, int], - _runtime_vars: Dict[str, Any], # serialized and sent to _child + bind_addr: tuple[str, int], + parent_addr: tuple[str, int], + _runtime_vars: dict[str, Any], # serialized and sent to _child *, @@ -447,20 +439,30 @@ async def mp_new_proc( name: str, actor_nursery: 'ActorNursery', # type: ignore # noqa subactor: Actor, - errors: Dict[Tuple[str, str], Exception], + errors: dict[tuple[str, str], Exception], # passed through to actor main - bind_addr: Tuple[str, int], - parent_addr: Tuple[str, int], - _runtime_vars: Dict[str, Any], # serialized and sent to _child + bind_addr: tuple[str, int], + parent_addr: tuple[str, int], + _runtime_vars: dict[str, Any], # serialized and sent to _child *, infect_asyncio: bool = False, task_status: TaskStatus[Portal] = trio.TASK_STATUS_IGNORED ) -> None: + # uggh zone + try: + from multiprocessing import semaphore_tracker # type: ignore + resource_tracker = semaphore_tracker + resource_tracker._resource_tracker = resource_tracker._semaphore_tracker # noqa + except ImportError: + # 3.8 introduces a more general version that also tracks shared mems + from multiprocessing import resource_tracker # type: ignore + assert _ctx start_method = _ctx.get_start_method() if start_method == 'forkserver': + from multiprocessing import forkserver # type: ignore # XXX do our hackery on the stdlib to avoid multiple # forkservers (one at each subproc layer). fs = forkserver._forkserver diff --git a/tractor/_state.py b/tractor/_state.py index 919c0cf..073bc99 100644 --- a/tractor/_state.py +++ b/tractor/_state.py @@ -20,7 +20,6 @@ Per process state """ from typing import Optional, Dict, Any from collections.abc import Mapping -import multiprocessing as mp import trio @@ -71,6 +70,7 @@ class ActorContextInfo(Mapping): def is_main_process() -> bool: """Bool determining if this actor is running in the top-most process. """ + import multiprocessing as mp return mp.current_process().name == 'MainProcess' diff --git a/tractor/_supervise.py b/tractor/_supervise.py index f2d907d..958d445 100644 --- a/tractor/_supervise.py +++ b/tractor/_supervise.py @@ -20,8 +20,7 @@ """ from functools import partial import inspect -import multiprocessing as mp -from typing import Tuple, List, Dict, Optional +from typing import Tuple, List, Dict, Optional, TYPE_CHECKING import typing import warnings @@ -39,6 +38,9 @@ from . import _state from . import _spawn +if TYPE_CHECKING: + import multiprocessing as mp + log = get_logger(__name__) _default_bind_addr: Tuple[str, int] = ('127.0.0.1', 0)