diff --git a/tractor/_forkserver_hackzorz.py b/tractor/_forkserver_hackzorz.py index d04f262..ec60b5c 100644 --- a/tractor/_forkserver_hackzorz.py +++ b/tractor/_forkserver_hackzorz.py @@ -13,7 +13,8 @@ import selectors import warnings from multiprocessing import ( - forkserver, semaphore_tracker, spawn, process, util + forkserver, semaphore_tracker, spawn, process, util, + connection ) from multiprocessing.forkserver import ( ForkServer, MAXFDS_TO_SEND @@ -72,6 +73,64 @@ class PatchedForkServer(ForkServer): os.close(child_r) os.close(child_w) + def ensure_running(self): + '''Make sure that a fork server is running. + + This can be called from any process. Note that usually a child + process will just reuse the forkserver started by its parent, so + ensure_running() will do nothing. + ''' + with self._lock: + semaphore_tracker.ensure_running() + if self._forkserver_pid is not None: + # forkserver was launched before, is it still running? + pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG) + if not pid: + # still alive + return + # dead, launch it again + os.close(self._forkserver_alive_fd) + self._forkserver_address = None + self._forkserver_alive_fd = None + self._forkserver_pid = None + + # XXX only thing that changed! + cmd = ('from tractor._forkserver_hackzorz import main; ' + + 'main(%d, %d, %r, **%r)') + + if self._preload_modules: + desired_keys = {'main_path', 'sys_path'} + data = spawn.get_preparation_data('ignore') + data = {x: y for x, y in data.items() if x in desired_keys} + else: + data = {} + + with socket.socket(socket.AF_UNIX) as listener: + address = connection.arbitrary_address('AF_UNIX') + listener.bind(address) + os.chmod(address, 0o600) + listener.listen() + + # all client processes own the write end of the "alive" pipe; + # when they all terminate the read end becomes ready. + alive_r, alive_w = os.pipe() + try: + fds_to_pass = [listener.fileno(), alive_r] + cmd %= (listener.fileno(), alive_r, self._preload_modules, + data) + exe = spawn.get_executable() + args = [exe] + util._args_from_interpreter_flags() + args += ['-c', cmd] + pid = util.spawnv_passfds(exe, args, fds_to_pass) + except: + os.close(alive_w) + raise + finally: + os.close(alive_r) + self._forkserver_address = address + self._forkserver_alive_fd = alive_w + self._forkserver_pid = pid + def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): '''Run forkserver.''' @@ -274,9 +333,9 @@ def override_stdlib(): semaphore_tracker.getfd = _semaphore_tracker.getfd forkserver._forkserver = _forkserver + forkserver.ensure_running = _forkserver.ensure_running forkserver.main = main forkserver._serve_one = _serve_one - forkserver.ensure_running = _forkserver.ensure_running forkserver.get_inherited_fds = _forkserver.get_inherited_fds forkserver.connect_to_new_process = _forkserver.connect_to_new_process forkserver.set_forkserver_preload = _forkserver.set_forkserver_preload