Be more explicit with naming and stdlib override

forkserver_singleton
Tyler Goodlet 2018-07-27 10:47:23 -04:00
parent 7017f68503
commit 4b875f0ade
3 changed files with 36 additions and 26 deletions

View File

@ -165,7 +165,7 @@ class Actor:
self._listeners = [] self._listeners = []
self._parent_chan = None self._parent_chan = None
self._accept_host = None self._accept_host = None
self._fs_deats = None self._forkserver_info = None
async def wait_for_peer(self, uid): async def wait_for_peer(self, uid):
"""Wait for a connection back from a spawned actor with a given """Wait for a connection back from a spawned actor with a given
@ -362,10 +362,10 @@ class Actor:
finally: finally:
log.debug(f"Exiting msg loop for {chan} from {chan.uid}") log.debug(f"Exiting msg loop for {chan} from {chan.uid}")
def _fork_main(self, accept_addr, fs_deats, parent_addr=None): def _fork_main(self, accept_addr, forkserver_info, parent_addr=None):
# after fork routine which invokes a fresh ``trio.run`` # after fork routine which invokes a fresh ``trio.run``
# log.warn("Log level after fork is {self.loglevel}") # log.warn("Log level after fork is {self.loglevel}")
self._fs_deats = fs_deats self._forkserver_info = forkserver_info
from ._trionics import ctx from ._trionics import ctx
if self.loglevel is not None: if self.loglevel is not None:
get_console_log(self.loglevel) get_console_log(self.loglevel)

View File

@ -256,16 +256,25 @@ class AdultSemaphoreTracker(semaphore_tracker.SemaphoreTracker):
return self._fd return self._fd
# override the stdlib's stuff
_semaphore_tracker = AdultSemaphoreTracker() _semaphore_tracker = AdultSemaphoreTracker()
_forkserver = AdultForkServer()
def override_stdlib():
"""Override the stdlib's ``multiprocessing.forkserver`` behaviour
such that our local "manually managed" version from above can be
used instead.
This allows avoiding spawning superfluous additional forkservers
and semaphore trackers for each actor nursery used to create new
sub-actors from sub-actors.
"""
semaphore_tracker._semaphore_tracker = _semaphore_tracker semaphore_tracker._semaphore_tracker = _semaphore_tracker
semaphore_tracker.ensure_running = _semaphore_tracker.ensure_running semaphore_tracker.ensure_running = _semaphore_tracker.ensure_running
semaphore_tracker.register = _semaphore_tracker.register semaphore_tracker.register = _semaphore_tracker.register
semaphore_tracker.unregister = _semaphore_tracker.unregister semaphore_tracker.unregister = _semaphore_tracker.unregister
semaphore_tracker.getfd = _semaphore_tracker.getfd semaphore_tracker.getfd = _semaphore_tracker.getfd
_forkserver = AdultForkServer()
forkserver._forkserver = _forkserver forkserver._forkserver = _forkserver
forkserver.main = main forkserver.main = main
forkserver._serve_one = _serve_one forkserver._serve_one = _serve_one

View File

@ -8,13 +8,14 @@ from multiprocessing import forkserver, semaphore_tracker
import trio import trio
from async_generator import asynccontextmanager, aclosing from async_generator import asynccontextmanager, aclosing
from . import _forkserver_hackzorz # overrides stdlib from . import _forkserver_hackzorz
from ._state import current_actor from ._state import current_actor
from .log import get_logger, get_loglevel from .log import get_logger, get_loglevel
from ._actor import Actor, ActorFailure from ._actor import Actor, ActorFailure
from ._portal import Portal from ._portal import Portal
_forkserver_hackzorz.override_stdlib()
ctx = mp.get_context("forkserver") ctx = mp.get_context("forkserver")
log = get_logger('tractor') log = get_logger('tractor')
@ -29,7 +30,7 @@ class ActorNursery:
# portals spawned with ``run_in_actor()`` # portals spawned with ``run_in_actor()``
self._cancel_after_result_on_exit = set() self._cancel_after_result_on_exit = set()
self.cancelled = False self.cancelled = False
self._fs = None self._forkserver = None
async def __aenter__(self): async def __aenter__(self):
return self return self
@ -53,16 +54,16 @@ class ActorNursery:
) )
parent_addr = self._actor.accept_addr parent_addr = self._actor.accept_addr
assert parent_addr assert parent_addr
self._fs = fs = forkserver._forkserver self._forkserver = fs = forkserver._forkserver
if mp.current_process().name == 'MainProcess' and ( if mp.current_process().name == 'MainProcess' and (
not self._actor._fs_deats not self._actor._forkserver_info
): ):
# if we're the "main" process start the forkserver only once # if we're the "main" process start the forkserver only once
# and pass it's ipc info to downstream children # and pass it's ipc info to downstream children
# forkserver.set_forkserver_preload(rpc_module_paths) # forkserver.set_forkserver_preload(rpc_module_paths)
forkserver.ensure_running() forkserver.ensure_running()
fs_deats = addr, alive_fd, pid, st_pid, st_fd = ( fs_info = addr, alive_fd, pid, st_pid, st_fd = (
fs._forkserver_address, fs._forkserver_address,
fs._forkserver_alive_fd, fs._forkserver_alive_fd,
getattr(fs, '_forkserver_pid', None), getattr(fs, '_forkserver_pid', None),
@ -70,17 +71,17 @@ class ActorNursery:
semaphore_tracker._semaphore_tracker._fd, semaphore_tracker._semaphore_tracker._fd,
) )
else: else:
fs_deats = ( fs_info = (
fs._forkserver_address, fs._forkserver_address,
fs._forkserver_alive_fd, fs._forkserver_alive_fd,
fs._forkserver_pid, fs._forkserver_pid,
semaphore_tracker._semaphore_tracker._pid, semaphore_tracker._semaphore_tracker._pid,
semaphore_tracker._semaphore_tracker._fd, semaphore_tracker._semaphore_tracker._fd,
) = self._actor._fs_deats ) = self._actor._forkserver_info
proc = ctx.Process( proc = ctx.Process(
target=actor._fork_main, target=actor._fork_main,
args=(bind_addr, fs_deats, parent_addr), args=(bind_addr, fs_info, parent_addr),
# daemon=True, # daemon=True,
name=name, name=name,
) )