Be more explicit with naming and stdlib override

forkserver_singleton
Tyler Goodlet 2018-07-27 10:47:23 -04:00
parent 7017f68503
commit 4b875f0ade
3 changed files with 36 additions and 26 deletions

View File

@ -165,7 +165,7 @@ class Actor:
self._listeners = []
self._parent_chan = None
self._accept_host = None
self._fs_deats = None
self._forkserver_info = None
async def wait_for_peer(self, uid):
"""Wait for a connection back from a spawned actor with a given
@ -362,10 +362,10 @@ class Actor:
finally:
log.debug(f"Exiting msg loop for {chan} from {chan.uid}")
def _fork_main(self, accept_addr, fs_deats, parent_addr=None):
def _fork_main(self, accept_addr, forkserver_info, parent_addr=None):
# after fork routine which invokes a fresh ``trio.run``
# log.warn("Log level after fork is {self.loglevel}")
self._fs_deats = fs_deats
self._forkserver_info = forkserver_info
from ._trionics import ctx
if self.loglevel is not None:
get_console_log(self.loglevel)

View File

@ -256,20 +256,29 @@ class AdultSemaphoreTracker(semaphore_tracker.SemaphoreTracker):
return self._fd
# override the stdlib's stuff
_semaphore_tracker = AdultSemaphoreTracker()
semaphore_tracker._semaphore_tracker = _semaphore_tracker
semaphore_tracker.ensure_running = _semaphore_tracker.ensure_running
semaphore_tracker.register = _semaphore_tracker.register
semaphore_tracker.unregister = _semaphore_tracker.unregister
semaphore_tracker.getfd = _semaphore_tracker.getfd
_forkserver = AdultForkServer()
forkserver._forkserver = _forkserver
forkserver.main = main
forkserver._serve_one = _serve_one
forkserver.ensure_running = _forkserver.ensure_running
forkserver.get_inherited_fds = _forkserver.get_inherited_fds
forkserver.connect_to_new_process = _forkserver.connect_to_new_process
forkserver.set_forkserver_preload = _forkserver.set_forkserver_preload
def override_stdlib():
"""Override the stdlib's ``multiprocessing.forkserver`` behaviour
such that our local "manually managed" version from above can be
used instead.
This allows avoiding spawning superfluous additional forkservers
and semaphore trackers for each actor nursery used to create new
sub-actors from sub-actors.
"""
semaphore_tracker._semaphore_tracker = _semaphore_tracker
semaphore_tracker.ensure_running = _semaphore_tracker.ensure_running
semaphore_tracker.register = _semaphore_tracker.register
semaphore_tracker.unregister = _semaphore_tracker.unregister
semaphore_tracker.getfd = _semaphore_tracker.getfd
forkserver._forkserver = _forkserver
forkserver.main = main
forkserver._serve_one = _serve_one
forkserver.ensure_running = _forkserver.ensure_running
forkserver.get_inherited_fds = _forkserver.get_inherited_fds
forkserver.connect_to_new_process = _forkserver.connect_to_new_process
forkserver.set_forkserver_preload = _forkserver.set_forkserver_preload

View File

@ -8,13 +8,14 @@ from multiprocessing import forkserver, semaphore_tracker
import trio
from async_generator import asynccontextmanager, aclosing
from . import _forkserver_hackzorz # overrides stdlib
from . import _forkserver_hackzorz
from ._state import current_actor
from .log import get_logger, get_loglevel
from ._actor import Actor, ActorFailure
from ._portal import Portal
_forkserver_hackzorz.override_stdlib()
ctx = mp.get_context("forkserver")
log = get_logger('tractor')
@ -29,7 +30,7 @@ class ActorNursery:
# portals spawned with ``run_in_actor()``
self._cancel_after_result_on_exit = set()
self.cancelled = False
self._fs = None
self._forkserver = None
async def __aenter__(self):
return self
@ -53,16 +54,16 @@ class ActorNursery:
)
parent_addr = self._actor.accept_addr
assert parent_addr
self._fs = fs = forkserver._forkserver
self._forkserver = fs = forkserver._forkserver
if mp.current_process().name == 'MainProcess' and (
not self._actor._fs_deats
not self._actor._forkserver_info
):
# if we're the "main" process start the forkserver only once
# and pass it's ipc info to downstream children
# forkserver.set_forkserver_preload(rpc_module_paths)
forkserver.ensure_running()
fs_deats = addr, alive_fd, pid, st_pid, st_fd = (
fs_info = addr, alive_fd, pid, st_pid, st_fd = (
fs._forkserver_address,
fs._forkserver_alive_fd,
getattr(fs, '_forkserver_pid', None),
@ -70,17 +71,17 @@ class ActorNursery:
semaphore_tracker._semaphore_tracker._fd,
)
else:
fs_deats = (
fs_info = (
fs._forkserver_address,
fs._forkserver_alive_fd,
fs._forkserver_pid,
semaphore_tracker._semaphore_tracker._pid,
semaphore_tracker._semaphore_tracker._fd,
) = self._actor._fs_deats
) = self._actor._forkserver_info
proc = ctx.Process(
target=actor._fork_main,
args=(bind_addr, fs_deats, parent_addr),
args=(bind_addr, fs_info, parent_addr),
# daemon=True,
name=name,
)