Add `Actor.cancel_task()`

Enable cancelling specific tasks from a peer actor such that when
an actor task or the actor itself is cancelled, remotely spawned tasks
can also be cancelled. In much the same way that you'd expect a node
(task) in the `trio` task tree to cancel any subtasks, actors should
be able to cancel any tasks they spawn in separate processes.

To enable this:
- track rpc tasks in a flat dict keyed by (chan, cid); a rough sketch
  of the new bookkeeping follows this list
- store an `is_complete` event to enable waiting on specific
  tasks to complete
- allow shielding the msg loop inside an internal cancel scope
  if requested by the caller; there was an issue with `open_portal()`
  where the channel would be torn down because the current task was
  cancelled, but messaging needs to continue until the portal block
  is exited
- for now, raise an error if the arbiter tries to find itself
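A rough standalone sketch of the bookkeeping described above (names mirror the diff below; `Channel` is a stand-in type and `trio.CancelScope` is used in place of the then-private `trio._core._run.CancelScope`):

    from typing import Any, Callable, Dict, Tuple
    import trio

    Channel = Any  # stand-in for tractor's IPC channel type


    class TaskRegistry:
        """Flat registry of rpc tasks keyed by (chan, cid)."""

        def __init__(self) -> None:
            # (chan, cid) -> (cancel_scope, func, is_complete)
            self._rpc_tasks: Dict[
                Tuple[Channel, str],
                Tuple[trio.CancelScope, Callable, trio.Event]
            ] = {}

        def register(
            self, chan: Channel, cid: str,
            scope: trio.CancelScope, func: Callable,
        ) -> None:
            # done by the msg loop right after spawning the rpc task
            self._rpc_tasks[(chan, cid)] = (scope, func, trio.Event())

        def mark_complete(self, chan: Channel, cid: str) -> None:
            # done in the rpc task's ``finally:`` block (cf. ``_invoke()``)
            _scope, _func, is_complete = self._rpc_tasks.pop((chan, cid))
            is_complete.set()

        async def cancel_task(self, chan: Channel, cid: str) -> None:
            # cancel the task's scope, then wait for its ``finally:`` to run
            scope, func, is_complete = self._rpc_tasks[(chan, cid)]
            scope.cancel()
            await is_complete.wait()

`cancel_task()` blocks on `is_complete` so a caller knows the target task's `finally:` block has actually run before it proceeds.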
Tyler Goodlet 2019-01-21 00:16:20 -05:00
parent 251ee177fa
commit 03e00886da
1 changed file with 139 additions and 106 deletions


@@ -140,18 +140,14 @@ async def _invoke(
         task_status.started(err)
     finally:
         # RPC task bookeeping
-        tasks = actor._rpc_tasks.get(chan, None)
-        if tasks:
-            try:
-                scope, func = tasks.pop(cid)
-            except ValueError:
-                # If we're cancelled before the task returns then the
-                # cancel scope will not have been inserted yet
-                log.warn(
-                    f"Task {func} was likely cancelled before it was started")
-
-            if not tasks:
-                actor._rpc_tasks.pop(chan, None)
+        try:
+            scope, func, is_complete = actor._rpc_tasks.pop((chan, cid))
+            is_complete.set()
+        except KeyError:
+            # If we're cancelled before the task returns then the
+            # cancel scope will not have been inserted yet
+            log.warn(
+                f"Task {func} was likely cancelled before it was started")

         if not actor._rpc_tasks:
             log.info(f"All RPC tasks have completed")
@@ -197,9 +193,10 @@ class Actor:
         self._no_more_rpc_tasks = trio.Event()
         self._no_more_rpc_tasks.set()
+        # (chan, cid) -> (cancel_scope, func)
         self._rpc_tasks: Dict[
-            Channel,
-            Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]]
+            Tuple[Channel, str],
+            Tuple[trio._core._run.CancelScope, typing.Callable, trio.Event]
         ] = {}
         # map {uids -> {callids -> waiter queues}}
         self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
@ -268,7 +265,7 @@ class Actor:
log.warning( log.warning(
f"already have channel(s) for {uid}:{chans}?" f"already have channel(s) for {uid}:{chans}?"
) )
log.debug(f"Registered {chan} for {uid}") log.trace(f"Registered {chan} for {uid}")
# append new channel # append new channel
self._peers[uid].append(chan) self._peers[uid].append(chan)
@@ -295,8 +292,9 @@ class Actor:
         if chan.connected():
             log.debug(f"Disconnecting channel {chan}")
             try:
+                # send our msg loop terminate sentinel
                 await chan.send(None)
-                await chan.aclose()
+                # await chan.aclose()
             except trio.BrokenResourceError:
                 log.exception(
                     f"Channel for {chan.uid} was already zonked..")
@@ -334,7 +332,10 @@ class Actor:
         return cid, q

     async def _process_messages(
-        self, chan: Channel, treat_as_gen: bool = False
+        self, chan: Channel,
+        treat_as_gen: bool = False,
+        shield: bool = False,
+        task_status=trio.TASK_STATUS_IGNORED,
     ) -> None:
         """Process messages for the channel async-RPC style.
@@ -342,91 +343,90 @@
         """
         # TODO: once https://github.com/python-trio/trio/issues/467 gets
         # worked out we'll likely want to use that!
+        msg = None
         log.debug(f"Entering msg loop for {chan} from {chan.uid}")
         try:
-            async for msg in chan:
-                if msg is None:  # terminate sentinel
-                    log.debug(
-                        f"Cancelling all tasks for {chan} from {chan.uid}")
-                    for cid, (scope, func) in self._rpc_tasks.pop(
-                        chan, {}
-                    ).items():
-                        scope.cancel()
-
-                    log.debug(
-                        f"Msg loop signalled to terminate for"
-                        f" {chan} from {chan.uid}")
-                    break
-
-                log.debug(f"Received msg {msg} from {chan.uid}")
-                cid = msg.get('cid')
-                if cid:
-                    cancel = msg.get('cancel')
-                    if cancel:
-                        # right now this is only implicitly used by
-                        # async generator IPC
-                        scope, func = self._rpc_tasks[chan][cid]
-                        log.debug(
-                            f"Received cancel request for task {cid}"
-                            f" from {chan.uid}")
-                        scope.cancel()
-                    else:
-                        # deliver response to local caller/waiter
-                        await self._push_result(chan.uid, cid, msg)
-                        log.debug(
-                            f"Waiting on next msg for {chan} from {chan.uid}")
-                    continue
-
-                # process command request
-                try:
-                    ns, funcname, kwargs, actorid, cid = msg['cmd']
-                except KeyError:
-                    # This is the non-rpc error case, that is, an
-                    # error **not** raised inside a call to ``_invoke()``
-                    # (i.e. no cid was provided in the msg - see above).
-                    # Push this error to all local channel consumers
-                    # (normally portals) by marking the channel as errored
-                    assert chan.uid
-                    exc = unpack_error(msg, chan=chan)
-                    chan._exc = exc
-                    raise exc
-
-                log.debug(
-                    f"Processing request from {actorid}\n"
-                    f"{ns}.{funcname}({kwargs})")
-                if ns == 'self':
-                    func = getattr(self, funcname)
-                else:
-                    # complain to client about restricted modules
-                    try:
-                        func = self._get_rpc_func(ns, funcname)
-                    except (ModuleNotExposed, AttributeError) as err:
-                        err_msg = pack_error(err)
-                        err_msg['cid'] = cid
-                        await chan.send(err_msg)
-                        continue
-
-                # spin up a task for the requested function
-                log.debug(f"Spawning task for {func}")
-                cs = await self._root_nursery.start(
-                    _invoke, self, cid, chan, func, kwargs,
-                    name=funcname
-                )
-                # never allow cancelling cancel requests (results in
-                # deadlock and other weird behaviour)
-                if func != self.cancel:
-                    if isinstance(cs, Exception):
-                        log.warn(f"Task for RPC func {func} failed with {cs}")
-                    else:
-                        # mark that we have ongoing rpc tasks
-                        self._no_more_rpc_tasks.clear()
-                        log.info(f"RPC func is {func}")
-                        # store cancel scope such that the rpc task can be
-                        # cancelled gracefully if requested
-                        self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func)
-                log.debug(
-                    f"Waiting on next msg for {chan} from {chan.uid}")
-            else:
-                # channel disconnect
-                log.debug(f"{chan} from {chan.uid} disconnected")
+            # internal scope allows for keeping this message
+            # loop running despite the current task having been
+            # cancelled (eg. `open_portal()` may call this method from
+            # a locally spawned task)
+            with trio.open_cancel_scope(shield=shield) as cs:
+                task_status.started(cs)
+                async for msg in chan:
+                    if msg is None:  # loop terminate sentinel
+                        log.debug(
+                            f"Cancelling all tasks for {chan} from {chan.uid}")
+                        for (channel, cid) in self._rpc_tasks:
+                            if channel is chan:
+                                self.cancel_task(cid, Context(channel, cid))
+                        log.debug(
+                            f"Msg loop signalled to terminate for"
+                            f" {chan} from {chan.uid}")
+                        break
+
+                    log.debug(f"Received msg {msg} from {chan.uid}")
+                    cid = msg.get('cid')
+                    if cid:
+                        # deliver response to local caller/waiter
+                        await self._push_result(chan.uid, cid, msg)
+                        log.debug(
+                            f"Waiting on next msg for {chan} from {chan.uid}")
+                        continue
+
+                    # process command request
+                    try:
+                        ns, funcname, kwargs, actorid, cid = msg['cmd']
+                    except KeyError:
+                        # This is the non-rpc error case, that is, an
+                        # error **not** raised inside a call to ``_invoke()``
+                        # (i.e. no cid was provided in the msg - see above).
+                        # Push this error to all local channel consumers
+                        # (normally portals) by marking the channel as errored
+                        assert chan.uid
+                        exc = unpack_error(msg, chan=chan)
+                        chan._exc = exc
+                        raise exc
+
+                    log.debug(
+                        f"Processing request from {actorid}\n"
+                        f"{ns}.{funcname}({kwargs})")
+                    if ns == 'self':
+                        func = getattr(self, funcname)
+                    else:
+                        # complain to client about restricted modules
+                        try:
+                            func = self._get_rpc_func(ns, funcname)
+                        except (ModuleNotExposed, AttributeError) as err:
+                            err_msg = pack_error(err)
+                            err_msg['cid'] = cid
+                            await chan.send(err_msg)
+                            continue
+
+                    # spin up a task for the requested function
+                    log.debug(f"Spawning task for {func}")
+                    cs = await self._root_nursery.start(
+                        _invoke, self, cid, chan, func, kwargs,
+                        name=funcname
+                    )
+                    # never allow cancelling cancel requests (results in
+                    # deadlock and other weird behaviour)
+                    if func != self.cancel:
+                        if isinstance(cs, Exception):
+                            log.warn(f"Task for RPC func {func} failed with"
+                                     f"{cs}")
+                        else:
+                            # mark that we have ongoing rpc tasks
+                            self._no_more_rpc_tasks.clear()
+                            log.info(f"RPC func is {func}")
+                            # store cancel scope such that the rpc task can be
+                            # cancelled gracefully if requested
+                            self._rpc_tasks[(chan, cid)] = (
+                                cs, func, trio.Event())
+                    log.debug(
+                        f"Waiting on next msg for {chan} from {chan.uid}")
+                else:
+                    # channel disconnect
+                    log.debug(f"{chan} from {chan.uid} disconnected")
         except trio.ClosedResourceError:
             log.error(f"{chan} form {chan.uid} broke")
@@ -439,8 +439,14 @@ class Actor:
             raise
             # if this is the `MainProcess` we expect the error broadcasting
             # above to trigger an error at consuming portal "checkpoints"
+        except trio.Cancelled:
+            # debugging only
+            log.debug("Msg loop was cancelled")
+            raise
         finally:
-            log.debug(f"Exiting msg loop for {chan} from {chan.uid}")
+            log.debug(
+                f"Exiting msg loop for {chan} from {chan.uid} "
+                f"with last msg:\n{msg}")

     def _fork_main(
         self,
@@ -541,8 +547,7 @@
             if self._parent_chan:
                 try:
                     # internal error so ship to parent without cid
-                    await self._parent_chan.send(
-                        pack_error(err))
+                    await self._parent_chan.send(pack_error(err))
                 except trio.ClosedResourceError:
                     log.error(
                         f"Failed to ship error to parent "
@@ -627,21 +632,47 @@
         self.cancel_server()
         self._root_nursery.cancel_scope.cancel()

+    async def cancel_task(self, cid, ctx):
+        """Cancel a local task.
+
+        Note this method will be treated as a streaming funciton
+        by remote actor-callers due to the declaration of ``ctx``
+        in the signature (for now).
+        """
+        # right now this is only implicitly called by
+        # streaming IPC but it should be called
+        # to cancel any remotely spawned task
+        chan = ctx.chan
+        # the ``dict.get()`` ensures the requested task to be cancelled
+        # was indeed spawned by a request from this channel
+        scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)]
+        log.debug(
+            f"Cancelling task:\ncid: {cid}\nfunc: {func}\n"
+            f"peer: {chan.uid}\n")
+
+        # if func is self.cancel_task:
+        #     return
+
+        scope.cancel()
+        # wait for _invoke to mark the task complete
+        await is_complete.wait()
+        log.debug(
+            f"Sucessfully cancelled task:\ncid: {cid}\nfunc: {func}\n"
+            f"peer: {chan.uid}\n")
+
     async def cancel_rpc_tasks(self) -> None:
         """Cancel all existing RPC responder tasks using the cancel scope
         registered for each.
         """
         tasks = self._rpc_tasks
-        log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}")
-        for chan, cids2scopes in tasks.items():
-            log.debug(f"Cancelling all tasks for {chan.uid}")
-            for cid, (scope, func) in cids2scopes.items():
-                log.debug(f"Cancelling task for {func}")
-                scope.cancel()
-        if tasks:
-            log.info(
-                f"Waiting for remaining rpc tasks to complete {tasks}")
-            await self._no_more_rpc_tasks.wait()
+        log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ")
+        for (chan, cid) in tasks.copy():
+            # TODO: this should really done in a nursery batch
+            await self.cancel_task(cid, Context(chan, cid))
+        # if tasks:
+        log.info(
+            f"Waiting for remaining rpc tasks to complete {tasks}")
+        await self._no_more_rpc_tasks.wait()

     def cancel_server(self) -> None:
         """Cancel the internal channel server nursery thereby
@@ -810,7 +841,9 @@ async def find_actor(
         sockaddr = await arb_portal.run('self', 'find_actor', name=name)
         # TODO: return portals to all available actors - for now just
         # the last one that registered
-        if sockaddr:
+        if name == 'arbiter' and actor.is_arbiter:
+            raise RuntimeError("The current actor is the arbiter")
+        elif sockaddr:
             async with _connect_chan(*sockaddr) as chan:
                 async with open_portal(chan) as portal:
                     yield portal
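
As a trailing note, the msg loop shielding added to `_process_messages()` follows a standard trio pattern: open an internal cancel scope, hand it back through `task_status.started()`, and let the starter cancel it explicitly once messaging is no longer needed. A minimal sketch, assuming modern trio naming (`trio.CancelScope` rather than the `trio.open_cancel_scope()` call used in the diff); the `msg_loop`/`main` names are illustrative only:

    import trio


    async def msg_loop(
        shield: bool = False,
        task_status=trio.TASK_STATUS_IGNORED,
    ) -> None:
        # an internal scope keeps this loop alive even if the task which
        # started it gets cancelled (the ``open_portal()`` situation)
        with trio.CancelScope(shield=shield) as cs:
            task_status.started(cs)  # hand the scope back to the starter
            while True:
                await trio.sleep(0.1)  # stand-in for ``async for msg in chan:``


    async def main() -> None:
        async with trio.open_nursery() as nursery:
            # the starter receives the loop's cancel scope...
            loop_cs = await nursery.start(msg_loop, True)
            await trio.sleep(0.5)  # ...does its "portal block" work here...
            loop_cs.cancel()       # ...then tears the msg loop down explicitly


    trio.run(main)

With `shield=True` the loop ignores cancellation from enclosing scopes, so only the explicit `loop_cs.cancel()` (or, in the real msg loop, the terminate sentinel) shuts it down.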