Add "spawn" start method support

Add full support for using the "spawn" process starting method as per:
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods

Add a `spawn_method` argument to `tractor.run()` for specifying the
desired method explicitly. By default use the "fastest" method available.
On *nix systems this is the original "forkserver" method.

This should be the solution to getting Windows support!

Resolves #60
spawn_method_support
Tyler Goodlet 2019-03-06 00:29:07 -05:00
parent d75739e9c7
commit 7014a07986
3 changed files with 70 additions and 30 deletions

View File

@ -19,6 +19,7 @@ from ._trionics import open_nursery
from ._state import current_actor from ._state import current_actor
from ._exceptions import RemoteActorError, ModuleNotExposed from ._exceptions import RemoteActorError, ModuleNotExposed
from . import msg from . import msg
from . import _spawn
__all__ = [ __all__ = [
@ -92,12 +93,14 @@ def run(
name: str = None, name: str = None,
arbiter_addr: Tuple[str, int] = ( arbiter_addr: Tuple[str, int] = (
_default_arbiter_host, _default_arbiter_port), _default_arbiter_host, _default_arbiter_port),
spawn_method: str = 'forkserver',
**kwargs: typing.Dict[str, typing.Any], **kwargs: typing.Dict[str, typing.Any],
) -> Any: ) -> Any:
"""Run a trio-actor async function in process. """Run a trio-actor async function in process.
This is tractor's main entry and the start point for any async actor. This is tractor's main entry and the start point for any async actor.
""" """
_spawn.try_set_start_method(spawn_method)
return trio.run(_main, async_fn, args, kwargs, name, arbiter_addr) return trio.run(_main, async_fn, args, kwargs, name, arbiter_addr)

View File

@ -391,7 +391,8 @@ class Actor:
f" {chan} from {chan.uid}") f" {chan} from {chan.uid}")
break break
log.trace(f"Received msg {msg} from {chan.uid}") # type: ignore log.trace( # type: ignore
f"Received msg {msg} from {chan.uid}")
if msg.get('cid'): if msg.get('cid'):
# deliver response to local caller/waiter # deliver response to local caller/waiter
await self._push_result(chan, msg) await self._push_result(chan, msg)
@ -478,18 +479,20 @@ class Actor:
self, self,
accept_addr: Tuple[str, int], accept_addr: Tuple[str, int],
forkserver_info: Tuple[Any, Any, Any, Any, Any], forkserver_info: Tuple[Any, Any, Any, Any, Any],
start_method: str,
parent_addr: Tuple[str, int] = None parent_addr: Tuple[str, int] = None
) -> None: ) -> None:
"""The routine called *after fork* which invokes a fresh ``trio.run`` """The routine called *after fork* which invokes a fresh ``trio.run``
""" """
self._forkserver_info = forkserver_info self._forkserver_info = forkserver_info
from ._spawn import ctx from ._spawn import try_set_start_method
spawn_ctx = try_set_start_method(start_method)
if self.loglevel is not None: if self.loglevel is not None:
log.info( log.info(
f"Setting loglevel for {self.uid} to {self.loglevel}") f"Setting loglevel for {self.uid} to {self.loglevel}")
get_console_log(self.loglevel) get_console_log(self.loglevel)
log.info( log.info(
f"Started new {ctx.current_process()} for {self.uid}") f"Started new {spawn_ctx.current_process()} for {self.uid}")
_state._current_actor = self _state._current_actor = self
log.debug(f"parent_addr is {parent_addr}") log.debug(f"parent_addr is {parent_addr}")
try: try:

View File

@ -5,15 +5,35 @@ Mostly just wrapping around ``multiprocessing``.
""" """
import multiprocessing as mp import multiprocessing as mp
from multiprocessing import forkserver, semaphore_tracker # type: ignore from multiprocessing import forkserver, semaphore_tracker # type: ignore
from typing import Tuple from typing import Tuple, Optional
from . import _forkserver_hackzorz from . import _forkserver_hackzorz
from ._state import current_actor from ._state import current_actor
from ._actor import Actor from ._actor import Actor
_forkserver_hackzorz.override_stdlib() _ctx: mp.context.BaseContext = mp.get_context("spawn")
ctx = mp.get_context("forkserver")
def try_set_start_method(name: str) -> mp.context.BaseContext:
    """Attempt to set the start method for ``multiprocessing.Process`` spawning.

    If the desired method is not supported on this platform the
    sub-interpreter (aka "spawn") method is used instead.

    Args:
        name: the requested start method, e.g. ``'forkserver'`` or
            ``'spawn'``.

    Returns:
        The ``multiprocessing`` context for the method actually selected.
        The same object is stored in the module-level ``_ctx``.
    """
    global _ctx

    if name not in mp.get_all_start_methods():
        # Fall back as documented.  This must be an *assignment*: the
        # previous ``name == 'spawn'`` was a no-op comparison, so an
        # unsupported method tripped the following assert instead of
        # falling back.
        name = 'spawn'

    if name == 'forkserver':
        # Only patch the stdlib when the forkserver method is in use.
        _forkserver_hackzorz.override_stdlib()

    # ``get_context`` raises ``ValueError`` itself if ``name`` is somehow
    # still unavailable, so no separate ``assert`` (stripped under ``-O``)
    # is needed.
    _ctx = mp.get_context(name)
    return _ctx
def is_main_process() -> bool: def is_main_process() -> bool:
@ -29,33 +49,47 @@ def new_proc(
bind_addr: Tuple[str, int], bind_addr: Tuple[str, int],
parent_addr: Tuple[str, int], parent_addr: Tuple[str, int],
) -> mp.Process: ) -> mp.Process:
fs = forkserver._forkserver """Create a new ``multiprocessing.Process`` using the
curr_actor = current_actor() spawn method as configured using ``try_set_start_method()``.
if is_main_process() and not curr_actor._forkserver_info: """
# if we're the "main" process start the forkserver only once start_method = _ctx.get_start_method()
# and pass its ipc info to downstream children if start_method == 'forkserver':
# forkserver.set_forkserver_preload(rpc_module_paths) # XXX do our hackery on the stdlib to avoid multiple
forkserver.ensure_running() # forkservers (one at each subproc layer).
fs_info = ( fs = forkserver._forkserver
fs._forkserver_address, curr_actor = current_actor()
fs._forkserver_alive_fd, if is_main_process() and not curr_actor._forkserver_info:
getattr(fs, '_forkserver_pid', None), # if we're the "main" process start the forkserver only once
getattr(semaphore_tracker._semaphore_tracker, '_pid', None), # and pass its ipc info to downstream children
semaphore_tracker._semaphore_tracker._fd, # forkserver.set_forkserver_preload(rpc_module_paths)
) forkserver.ensure_running()
fs_info = (
fs._forkserver_address,
fs._forkserver_alive_fd,
getattr(fs, '_forkserver_pid', None),
getattr(semaphore_tracker._semaphore_tracker, '_pid', None),
semaphore_tracker._semaphore_tracker._fd,
)
else:
assert curr_actor._forkserver_info
fs_info = (
fs._forkserver_address,
fs._forkserver_alive_fd,
fs._forkserver_pid,
semaphore_tracker._semaphore_tracker._pid,
semaphore_tracker._semaphore_tracker._fd,
) = curr_actor._forkserver_info
else: else:
assert curr_actor._forkserver_info fs_info = (None, None, None, None, None)
fs_info = (
fs._forkserver_address,
fs._forkserver_alive_fd,
fs._forkserver_pid,
semaphore_tracker._semaphore_tracker._pid,
semaphore_tracker._semaphore_tracker._fd,
) = curr_actor._forkserver_info
return ctx.Process( return _ctx.Process(
target=actor._fork_main, target=actor._fork_main,
args=(bind_addr, fs_info, parent_addr), args=(
bind_addr,
fs_info,
start_method,
parent_addr
),
# daemon=True, # daemon=True,
name=name, name=name,
) )