diff --git a/tractor/_actor.py b/tractor/_actor.py
index 7693035..fe38929 100644
--- a/tractor/_actor.py
+++ b/tractor/_actor.py
@@ -165,6 +165,7 @@ class Actor:
         self._listeners = []
         self._parent_chan = None
         self._accept_host = None
+        self._forkserver_info = None
 
     async def wait_for_peer(self, uid):
         """Wait for a connection back from a spawned actor with a given
@@ -361,9 +362,10 @@ class Actor:
         finally:
             log.debug(f"Exiting msg loop for {chan} from {chan.uid}")
 
-    def _fork_main(self, accept_addr, parent_addr=None):
+    def _fork_main(self, accept_addr, forkserver_info, parent_addr=None):
         # after fork routine which invokes a fresh ``trio.run``
         # log.warn("Log level after fork is {self.loglevel}")
+        self._forkserver_info = forkserver_info
         from ._trionics import ctx
         if self.loglevel is not None:
             get_console_log(self.loglevel)
diff --git a/tractor/_forkserver_hackzorz.py b/tractor/_forkserver_hackzorz.py
new file mode 100644
index 0000000..30b919d
--- /dev/null
+++ b/tractor/_forkserver_hackzorz.py
@@ -0,0 +1,284 @@
+"""
+This is a near-copy of the 3.8 stdlib's ``multiprocessing.forkserver.py``
+with some hackery to prevent more than a single forkserver and
+semaphore tracker per ``MainProcess``.
+"""
+import os
+import socket
+import signal
+import struct
+import sys
+import errno
+import selectors
+import warnings
+
+from multiprocessing import (
+    forkserver, semaphore_tracker, spawn, process, util
+)
+from multiprocessing.forkserver import (
+    ForkServer, MAXFDS_TO_SEND, write_unsigned,
+    # _serve_one,
+)
+from multiprocessing.context import reduction
+
+
+# taken from 3.8
+SIGNED_STRUCT = struct.Struct('q')  # large enough for pid_t
+
+
+class PatchedForkServer(ForkServer):
+
+    def connect_to_new_process(self, fds):
+        '''Request forkserver to create a child process.
+
+        Returns a pair of fds (status_r, data_w). The calling process can read
+        the child process's pid and (eventually) its returncode from status_r.
+        The calling process should write to data_w the pickled preparation and
+        process data.
+        '''
+        # self.ensure_running()  # treat our users like adults expecting
+        # them to spawn the server on their own
+        if len(fds) + 4 >= MAXFDS_TO_SEND:
+            raise ValueError('too many fds')
+        with socket.socket(socket.AF_UNIX) as client:
+            client.connect(self._forkserver_address)
+            parent_r, child_w = os.pipe()
+            child_r, parent_w = os.pipe()
+            allfds = [child_r, child_w, self._forkserver_alive_fd,
+                      semaphore_tracker.getfd()]
+            allfds += fds
+
+            # This is the only part changed
+            try:
+                while True:
+                    try:
+                        reduction.sendfds(client, allfds)
+                        break
+                    except OSError as err:
+                        if err.errno == errno.EBADF:
+                            print(f"Bad FD {err}")
+                            client = socket.socket(socket.AF_UNIX)
+                            client.connect(self._forkserver_address)
+                            continue
+                        raise
+                # This is the only part changed
+
+                return parent_r, parent_w
+            except Exception:
+                os.close(parent_r)
+                os.close(parent_w)
+                raise
+            finally:
+                os.close(child_r)
+                os.close(child_w)
+
+
+def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
+    '''Run forkserver.'''
+    if preload:
+        if '__main__' in preload and main_path is not None:
+            process.current_process()._inheriting = True
+            try:
+                spawn.import_main_path(main_path)
+            finally:
+                del process.current_process()._inheriting
+        for modname in preload:
+            try:
+                __import__(modname)
+            except ImportError:
+                pass
+
+    util._close_stdin()
+
+    sig_r, sig_w = os.pipe()
+    os.set_blocking(sig_r, False)
+    os.set_blocking(sig_w, False)
+
+    def sigchld_handler(*_unused):
+        # Dummy signal handler, doesn't do anything
+        pass
+
+    handlers = {
+        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
+        signal.SIGCHLD: sigchld_handler,
+        # protect the process from ^C
+        signal.SIGINT: signal.SIG_IGN,
+    }
+    old_handlers = {sig: signal.signal(sig, val)
+                    for (sig, val) in handlers.items()}
+
+    # calling os.write() in the Python signal handler is racy
+    signal.set_wakeup_fd(sig_w)
+
+    # map child pids to client fds
+    pid_to_fd = {}
+
+    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
+            selectors.DefaultSelector() as selector:
+        _forkserver._forkserver_address = listener.getsockname()
+
+        selector.register(listener, selectors.EVENT_READ)
+        selector.register(alive_r, selectors.EVENT_READ)
+        selector.register(sig_r, selectors.EVENT_READ)
+
+        while True:
+            try:
+                while True:
+                    rfds = [key.fileobj for (key, events) in selector.select()]
+                    if rfds:
+                        break
+
+                if alive_r in rfds:
+                    # EOF because no more client processes left
+                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
+                    raise SystemExit
+
+                if sig_r in rfds:
+                    # Got SIGCHLD
+                    os.read(sig_r, 65536)  # exhaust
+                    while True:
+                        # Scan for child processes
+                        try:
+                            pid, sts = os.waitpid(-1, os.WNOHANG)
+                        except ChildProcessError:
+                            break
+                        if pid == 0:
+                            break
+                        child_w = pid_to_fd.pop(pid, None)
+                        if child_w is not None:
+                            if os.WIFSIGNALED(sts):
+                                returncode = -os.WTERMSIG(sts)
+                            else:
+                                if not os.WIFEXITED(sts):
+                                    raise AssertionError(
+                                        "Child {0:n} status is {1:n}".format(
+                                            pid, sts))
+                                returncode = os.WEXITSTATUS(sts)
+                            # Send exit code to client process
+                            try:
+                                # write_signed(child_w, returncode)
+                                write_unsigned(child_w, returncode)
+                            except BrokenPipeError:
+                                # client vanished
+                                pass
+                            os.close(child_w)
+                        else:
+                            # This shouldn't happen really
+                            warnings.warn('forkserver: waitpid returned '
+                                          'unexpected pid %d' % pid)
+
+                if listener in rfds:
+                    # Incoming fork request
+                    with listener.accept()[0] as s:
+
+                        # Thing changed - be tolerant of socket disconnects
+                        try:
+                            # Receive fds from client
+                            fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
+                        except EOFError:
+                            # broken socket due to reconnection on client-side
+                            continue
+                        # Thing changed - be tolerant of socket disconnects
+
+                        if len(fds) > MAXFDS_TO_SEND:
+                            raise RuntimeError(
+                                "Too many ({0:n}) fds to send".format(
+                                    len(fds)))
+                        child_r, child_w, *fds = fds
+                        s.close()
+                        pid = os.fork()
+                        if pid == 0:
+                            # Child
+                            code = 1
+                            try:
+                                listener.close()
+                                selector.close()
+                                unused_fds = [alive_r, child_w, sig_r, sig_w]
+                                unused_fds.extend(pid_to_fd.values())
+                                code = _serve_one(child_r, fds,
+                                                  unused_fds,
+                                                  old_handlers)
+                            except Exception:
+                                sys.excepthook(*sys.exc_info())
+                                sys.stderr.flush()
+                            finally:
+                                os._exit(code)
+                        else:
+                            # Send pid to client process
+                            try:
+                                # write_signed(child_w, pid)
+                                write_unsigned(child_w, pid)
+                            except BrokenPipeError:
+                                # client vanished
+                                pass
+                            pid_to_fd[pid] = child_w
+                            os.close(child_r)
+                            for fd in fds:
+                                os.close(fd)
+
+            except OSError as e:
+                if e.errno != errno.ECONNABORTED:
+                    raise
+
+
+def _serve_one(child_r, fds, unused_fds, handlers):
+    # close unnecessary stuff and reset signal handlers
+    signal.set_wakeup_fd(-1)
+    for sig, val in handlers.items():
+        signal.signal(sig, val)
+    for fd in unused_fds:
+        os.close(fd)
+
+    (_forkserver._forkserver_alive_fd,
+     semaphore_tracker._semaphore_tracker._fd,
+     *_forkserver._inherited_fds) = fds
+
+    # Run process object received over pipe
+    code = spawn._main(child_r)
+
+    return code
+
+
+def write_signed(fd, n):
+    msg = SIGNED_STRUCT.pack(n)
+    while msg:
+        nbytes = os.write(fd, msg)
+        if nbytes == 0:
+            raise RuntimeError('should not get here')
+        msg = msg[nbytes:]
+
+
+class PatchedSemaphoreTracker(semaphore_tracker.SemaphoreTracker):
+    """Stop GD ensuring everything is running...
+    """
+    def getfd(self):
+        # self.ensure_running()
+        return self._fd
+
+
+_semaphore_tracker = PatchedSemaphoreTracker()
+_forkserver = PatchedForkServer()
+
+
+def override_stdlib():
+    """Override the stdlib's ``multiprocessing.forkserver`` behaviour
+    such that our local "manually managed" version from above can be
+    used instead.
+
+    This avoids spawning superfluous forkservers and semaphore trackers
+    for each actor nursery used to create new sub-actors from
+    sub-actors.
+    """
+    semaphore_tracker._semaphore_tracker = _semaphore_tracker
+    semaphore_tracker.ensure_running = _semaphore_tracker.ensure_running
+    semaphore_tracker.register = _semaphore_tracker.register
+    semaphore_tracker.unregister = _semaphore_tracker.unregister
+    semaphore_tracker.getfd = _semaphore_tracker.getfd
+
+    forkserver._forkserver = _forkserver
+    forkserver.main = main
+    forkserver._serve_one = _serve_one
+    forkserver.ensure_running = _forkserver.ensure_running
+    forkserver.get_inherited_fds = _forkserver.get_inherited_fds
+    forkserver.connect_to_new_process = _forkserver.connect_to_new_process
+    forkserver.set_forkserver_preload = _forkserver.set_forkserver_preload
diff --git a/tractor/_ipc.py b/tractor/_ipc.py
index 56c09db..c8a811b 100644
--- a/tractor/_ipc.py
+++ b/tractor/_ipc.py
@@ -57,7 +57,7 @@ class StreamQueue:
     async def get(self):
         return await self._agen.asend(None)
 
-    async def __aiter__(self):
+    def __aiter__(self):
         return self._agen
 
     def connected(self):
diff --git a/tractor/_portal.py b/tractor/_portal.py
index ebae431..d9ee957 100644
--- a/tractor/_portal.py
+++ b/tractor/_portal.py
@@ -173,7 +173,7 @@ class Portal:
             # send cancel cmd - might not get response
             await self.run('self', 'cancel')
             return True
-        except trio.ClosedStreamError:
+        except trio.ClosedResourceError:
             log.warn(
                 f"{self.channel} for {self.channel.uid} was already closed?")
             return False
diff --git a/tractor/_trionics.py b/tractor/_trionics.py
index aa9becb..569f731 100644
--- a/tractor/_trionics.py
+++ b/tractor/_trionics.py
@@ -3,16 +3,19 @@
 """
 import multiprocessing as mp
 import inspect
+from multiprocessing import forkserver, semaphore_tracker
 
 import trio
 from async_generator import asynccontextmanager, aclosing
 
+from . import _forkserver_hackzorz
 from ._state import current_actor
 from .log import get_logger, get_loglevel
 from ._actor import Actor, ActorFailure
 from ._portal import Portal
 
 
+_forkserver_hackzorz.override_stdlib()
 ctx = mp.get_context("forkserver")
 log = get_logger('tractor')
 
@@ -27,6 +30,7 @@ class ActorNursery:
         # portals spawned with ``run_in_actor()``
         self._cancel_after_result_on_exit = set()
         self.cancelled = False
+        self._forkserver = None
 
     async def __aenter__(self):
         return self
@@ -50,9 +54,33 @@
         )
         parent_addr = self._actor.accept_addr
         assert parent_addr
+        self._forkserver = fs = forkserver._forkserver
+        if mp.current_process().name == 'MainProcess' and (
+                not self._actor._forkserver_info
+        ):
+            # if we're the "main" process start the forkserver only once
+            # and pass its ipc info to downstream children
+            # forkserver.set_forkserver_preload(rpc_module_paths)
+            forkserver.ensure_running()
+            fs_info = (
+                fs._forkserver_address,
+                fs._forkserver_alive_fd,
+                getattr(fs, '_forkserver_pid', None),
+                getattr(semaphore_tracker._semaphore_tracker, '_pid', None),
+                semaphore_tracker._semaphore_tracker._fd,
+            )
+        else:
+            fs_info = (
+                fs._forkserver_address,
+                fs._forkserver_alive_fd,
+                fs._forkserver_pid,
+                semaphore_tracker._semaphore_tracker._pid,
+                semaphore_tracker._semaphore_tracker._fd,
+            ) = self._actor._forkserver_info
+
         proc = ctx.Process(
             target=actor._fork_main,
-            args=(bind_addr, fs_info, parent_addr),
+            args=(bind_addr, fs_info, parent_addr),
             # daemon=True,
             name=name,
         )
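
The hand-off above works in two halves: the ``MainProcess`` branch in ``_trionics.py`` snapshots its single running forkserver and semaphore tracker into ``fs_info``, and every sub-actor later unpacks that same tuple back onto the patched singletons (the ``else`` branch) so its own nursery reuses the parent's server. A minimal standalone sketch of the two sides, assuming Python < 3.8 where ``multiprocessing.semaphore_tracker`` still exists; ``pack_fs_info``/``apply_fs_info`` are hypothetical names introduced only for illustration, the patch inlines both:

    from multiprocessing import forkserver, semaphore_tracker


    def pack_fs_info():
        # parent side: snapshot the already-running forkserver + tracker
        fs = forkserver._forkserver
        st = semaphore_tracker._semaphore_tracker
        return (
            fs._forkserver_address,
            fs._forkserver_alive_fd,
            getattr(fs, '_forkserver_pid', None),
            getattr(st, '_pid', None),
            st._fd,
        )


    def apply_fs_info(fs_info):
        # child side: point the singletons at the parent's forkserver so
        # connect_to_new_process() reuses it instead of starting another
        # server/tracker pair
        fs = forkserver._forkserver
        st = semaphore_tracker._semaphore_tracker
        (fs._forkserver_address,
         fs._forkserver_alive_fd,
         fs._forkserver_pid,
         st._pid,
         st._fd) = fs_info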
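
One ordering detail the patch relies on: ``override_stdlib()`` must run before the first process is spawned from the "forkserver" context, which is why ``_trionics.py`` calls it at import time immediately before ``mp.get_context("forkserver")``. A rough sketch of that sequence under the same Python < 3.8 assumption:

    import multiprocessing as mp
    from tractor import _forkserver_hackzorz

    # swap the stdlib's forkserver/semaphore-tracker singletons for the
    # patched "manually managed" ones defined in _forkserver_hackzorz
    _forkserver_hackzorz.override_stdlib()

    # the same context _trionics.py grabs; every ctx.Process() started in
    # this MainProcess now goes through the single patched forkserver
    # instead of spawning a new one per actor nursery
    ctx = mp.get_context("forkserver")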