diff --git a/tractor/_actor.py b/tractor/_actor.py index eb6002e..3d0a4f5 100644 --- a/tractor/_actor.py +++ b/tractor/_actor.py @@ -140,18 +140,14 @@ async def _invoke( task_status.started(err) finally: # RPC task bookeeping - tasks = actor._rpc_tasks.get(chan, None) - if tasks: - try: - scope, func = tasks.pop(cid) - except ValueError: - # If we're cancelled before the task returns then the - # cancel scope will not have been inserted yet - log.warn( - f"Task {func} was likely cancelled before it was started") - - if not tasks: - actor._rpc_tasks.pop(chan, None) + try: + scope, func, is_complete = actor._rpc_tasks.pop((chan, cid)) + is_complete.set() + except KeyError: + # If we're cancelled before the task returns then the + # cancel scope will not have been inserted yet + log.warn( + f"Task {func} was likely cancelled before it was started") if not actor._rpc_tasks: log.info(f"All RPC tasks have completed") @@ -197,9 +193,10 @@ class Actor: self._no_more_rpc_tasks = trio.Event() self._no_more_rpc_tasks.set() + # (chan, cid) -> (cancel_scope, func) self._rpc_tasks: Dict[ - Channel, - Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]] + Tuple[Channel, str], + Tuple[trio._core._run.CancelScope, typing.Callable, trio.Event] ] = {} # map {uids -> {callids -> waiter queues}} self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {} @@ -268,7 +265,7 @@ class Actor: log.warning( f"already have channel(s) for {uid}:{chans}?" ) - log.debug(f"Registered {chan} for {uid}") + log.trace(f"Registered {chan} for {uid}") # append new channel self._peers[uid].append(chan) @@ -295,8 +292,9 @@ class Actor: if chan.connected(): log.debug(f"Disconnecting channel {chan}") try: + # send our msg loop terminate sentinel await chan.send(None) - await chan.aclose() + # await chan.aclose() except trio.BrokenResourceError: log.exception( f"Channel for {chan.uid} was already zonked..") @@ -334,7 +332,10 @@ class Actor: return cid, q async def _process_messages( - self, chan: Channel, treat_as_gen: bool = False + self, chan: Channel, + treat_as_gen: bool = False, + shield: bool = False, + task_status=trio.TASK_STATUS_IGNORED, ) -> None: """Process messages for the channel async-RPC style. @@ -342,91 +343,90 @@ class Actor: """ # TODO: once https://github.com/python-trio/trio/issues/467 gets # worked out we'll likely want to use that! + msg = None log.debug(f"Entering msg loop for {chan} from {chan.uid}") try: - async for msg in chan: - if msg is None: # terminate sentinel - log.debug( - f"Cancelling all tasks for {chan} from {chan.uid}") - for cid, (scope, func) in self._rpc_tasks.pop( - chan, {} - ).items(): - scope.cancel() - log.debug( - f"Msg loop signalled to terminate for" - f" {chan} from {chan.uid}") - break - log.debug(f"Received msg {msg} from {chan.uid}") - cid = msg.get('cid') - if cid: - cancel = msg.get('cancel') - if cancel: - # right now this is only implicitly used by - # async generator IPC - scope, func = self._rpc_tasks[chan][cid] + # internal scope allows for keeping this message + # loop running despite the current task having been + # cancelled (eg. `open_portal()` may call this method from + # a locally spawned task) + with trio.open_cancel_scope(shield=shield) as cs: + task_status.started(cs) + async for msg in chan: + if msg is None: # loop terminate sentinel log.debug( - f"Received cancel request for task {cid}" - f" from {chan.uid}") - scope.cancel() - else: + f"Cancelling all tasks for {chan} from {chan.uid}") + for (channel, cid) in self._rpc_tasks: + if channel is chan: + self.cancel_task(cid, Context(channel, cid)) + log.debug( + f"Msg loop signalled to terminate for" + f" {chan} from {chan.uid}") + break + + log.debug(f"Received msg {msg} from {chan.uid}") + cid = msg.get('cid') + if cid: # deliver response to local caller/waiter await self._push_result(chan.uid, cid, msg) log.debug( f"Waiting on next msg for {chan} from {chan.uid}") - continue - - # process command request - try: - ns, funcname, kwargs, actorid, cid = msg['cmd'] - except KeyError: - # This is the non-rpc error case, that is, an - # error **not** raised inside a call to ``_invoke()`` - # (i.e. no cid was provided in the msg - see above). - # Push this error to all local channel consumers - # (normally portals) by marking the channel as errored - assert chan.uid - exc = unpack_error(msg, chan=chan) - chan._exc = exc - raise exc - - log.debug( - f"Processing request from {actorid}\n" - f"{ns}.{funcname}({kwargs})") - if ns == 'self': - func = getattr(self, funcname) - else: - # complain to client about restricted modules - try: - func = self._get_rpc_func(ns, funcname) - except (ModuleNotExposed, AttributeError) as err: - err_msg = pack_error(err) - err_msg['cid'] = cid - await chan.send(err_msg) continue - # spin up a task for the requested function - log.debug(f"Spawning task for {func}") - cs = await self._root_nursery.start( - _invoke, self, cid, chan, func, kwargs, - name=funcname - ) - # never allow cancelling cancel requests (results in - # deadlock and other weird behaviour) - if func != self.cancel: - if isinstance(cs, Exception): - log.warn(f"Task for RPC func {func} failed with {cs}") + # process command request + try: + ns, funcname, kwargs, actorid, cid = msg['cmd'] + except KeyError: + # This is the non-rpc error case, that is, an + # error **not** raised inside a call to ``_invoke()`` + # (i.e. no cid was provided in the msg - see above). + # Push this error to all local channel consumers + # (normally portals) by marking the channel as errored + assert chan.uid + exc = unpack_error(msg, chan=chan) + chan._exc = exc + raise exc + + log.debug( + f"Processing request from {actorid}\n" + f"{ns}.{funcname}({kwargs})") + if ns == 'self': + func = getattr(self, funcname) else: - # mark that we have ongoing rpc tasks - self._no_more_rpc_tasks.clear() - log.info(f"RPC func is {func}") - # store cancel scope such that the rpc task can be - # cancelled gracefully if requested - self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func) - log.debug( - f"Waiting on next msg for {chan} from {chan.uid}") - else: - # channel disconnect - log.debug(f"{chan} from {chan.uid} disconnected") + # complain to client about restricted modules + try: + func = self._get_rpc_func(ns, funcname) + except (ModuleNotExposed, AttributeError) as err: + err_msg = pack_error(err) + err_msg['cid'] = cid + await chan.send(err_msg) + continue + + # spin up a task for the requested function + log.debug(f"Spawning task for {func}") + cs = await self._root_nursery.start( + _invoke, self, cid, chan, func, kwargs, + name=funcname + ) + # never allow cancelling cancel requests (results in + # deadlock and other weird behaviour) + if func != self.cancel: + if isinstance(cs, Exception): + log.warn(f"Task for RPC func {func} failed with" + f"{cs}") + else: + # mark that we have ongoing rpc tasks + self._no_more_rpc_tasks.clear() + log.info(f"RPC func is {func}") + # store cancel scope such that the rpc task can be + # cancelled gracefully if requested + self._rpc_tasks[(chan, cid)] = ( + cs, func, trio.Event()) + log.debug( + f"Waiting on next msg for {chan} from {chan.uid}") + else: + # channel disconnect + log.debug(f"{chan} from {chan.uid} disconnected") except trio.ClosedResourceError: log.error(f"{chan} form {chan.uid} broke") @@ -439,8 +439,14 @@ class Actor: raise # if this is the `MainProcess` we expect the error broadcasting # above to trigger an error at consuming portal "checkpoints" + except trio.Cancelled: + # debugging only + log.debug("Msg loop was cancelled") + raise finally: - log.debug(f"Exiting msg loop for {chan} from {chan.uid}") + log.debug( + f"Exiting msg loop for {chan} from {chan.uid} " + f"with last msg:\n{msg}") def _fork_main( self, @@ -541,8 +547,7 @@ class Actor: if self._parent_chan: try: # internal error so ship to parent without cid - await self._parent_chan.send( - pack_error(err)) + await self._parent_chan.send(pack_error(err)) except trio.ClosedResourceError: log.error( f"Failed to ship error to parent " @@ -627,21 +632,47 @@ class Actor: self.cancel_server() self._root_nursery.cancel_scope.cancel() + async def cancel_task(self, cid, ctx): + """Cancel a local task. + + Note this method will be treated as a streaming funciton + by remote actor-callers due to the declaration of ``ctx`` + in the signature (for now). + """ + # right now this is only implicitly called by + # streaming IPC but it should be called + # to cancel any remotely spawned task + chan = ctx.chan + # the ``dict.get()`` ensures the requested task to be cancelled + # was indeed spawned by a request from this channel + scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)] + log.debug( + f"Cancelling task:\ncid: {cid}\nfunc: {func}\n" + f"peer: {chan.uid}\n") + + # if func is self.cancel_task: + # return + + scope.cancel() + # wait for _invoke to mark the task complete + await is_complete.wait() + log.debug( + f"Sucessfully cancelled task:\ncid: {cid}\nfunc: {func}\n" + f"peer: {chan.uid}\n") + async def cancel_rpc_tasks(self) -> None: """Cancel all existing RPC responder tasks using the cancel scope registered for each. """ tasks = self._rpc_tasks - log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}") - for chan, cids2scopes in tasks.items(): - log.debug(f"Cancelling all tasks for {chan.uid}") - for cid, (scope, func) in cids2scopes.items(): - log.debug(f"Cancelling task for {func}") - scope.cancel() - if tasks: - log.info( - f"Waiting for remaining rpc tasks to complete {tasks}") - await self._no_more_rpc_tasks.wait() + log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ") + for (chan, cid) in tasks.copy(): + # TODO: this should really done in a nursery batch + await self.cancel_task(cid, Context(chan, cid)) + # if tasks: + log.info( + f"Waiting for remaining rpc tasks to complete {tasks}") + await self._no_more_rpc_tasks.wait() def cancel_server(self) -> None: """Cancel the internal channel server nursery thereby @@ -810,7 +841,9 @@ async def find_actor( sockaddr = await arb_portal.run('self', 'find_actor', name=name) # TODO: return portals to all available actors - for now just # the last one that registered - if sockaddr: + if name == 'arbiter' and actor.is_arbiter: + raise RuntimeError("The current actor is the arbiter") + elif sockaddr: async with _connect_chan(*sockaddr) as chan: async with open_portal(chan) as portal: yield portal