Add `Actor.cancel_task()`

Enable cancelling specific tasks from a peer actor such that when
a actor task or the actor itself is cancelled, remotely spawned tasks
can also be cancelled. In much that same way that you'd expect a node
(task) in the `trio` task tree to cancel any subtasks, actors should
be able to cancel any tasks they spawn in separate processes.

To enable this:
- track rpc tasks in a flat dict keyed by (chan, cid)
- store a `is_complete` event to enable waiting on specific
  tasks to complete
- allow for shielding the msg loop inside an internal cancel scope
  if requested by the caller; there was an issue with `open_portal()`
  where the channel would be torn down because the current task was
  cancelled but we still need messaging to continue until the portal
  block is exited
- throw an error if the arbiter tries to find itself for now
contexts
Tyler Goodlet 2019-01-21 00:16:20 -05:00
parent 251ee177fa
commit 03e00886da
1 changed files with 139 additions and 106 deletions

View File

@ -140,19 +140,15 @@ async def _invoke(
task_status.started(err) task_status.started(err)
finally: finally:
# RPC task bookeeping # RPC task bookeeping
tasks = actor._rpc_tasks.get(chan, None)
if tasks:
try: try:
scope, func = tasks.pop(cid) scope, func, is_complete = actor._rpc_tasks.pop((chan, cid))
except ValueError: is_complete.set()
except KeyError:
# If we're cancelled before the task returns then the # If we're cancelled before the task returns then the
# cancel scope will not have been inserted yet # cancel scope will not have been inserted yet
log.warn( log.warn(
f"Task {func} was likely cancelled before it was started") f"Task {func} was likely cancelled before it was started")
if not tasks:
actor._rpc_tasks.pop(chan, None)
if not actor._rpc_tasks: if not actor._rpc_tasks:
log.info(f"All RPC tasks have completed") log.info(f"All RPC tasks have completed")
actor._no_more_rpc_tasks.set() actor._no_more_rpc_tasks.set()
@ -197,9 +193,10 @@ class Actor:
self._no_more_rpc_tasks = trio.Event() self._no_more_rpc_tasks = trio.Event()
self._no_more_rpc_tasks.set() self._no_more_rpc_tasks.set()
# (chan, cid) -> (cancel_scope, func)
self._rpc_tasks: Dict[ self._rpc_tasks: Dict[
Channel, Tuple[Channel, str],
Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]] Tuple[trio._core._run.CancelScope, typing.Callable, trio.Event]
] = {} ] = {}
# map {uids -> {callids -> waiter queues}} # map {uids -> {callids -> waiter queues}}
self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {} self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
@ -268,7 +265,7 @@ class Actor:
log.warning( log.warning(
f"already have channel(s) for {uid}:{chans}?" f"already have channel(s) for {uid}:{chans}?"
) )
log.debug(f"Registered {chan} for {uid}") log.trace(f"Registered {chan} for {uid}")
# append new channel # append new channel
self._peers[uid].append(chan) self._peers[uid].append(chan)
@ -295,8 +292,9 @@ class Actor:
if chan.connected(): if chan.connected():
log.debug(f"Disconnecting channel {chan}") log.debug(f"Disconnecting channel {chan}")
try: try:
# send our msg loop terminate sentinel
await chan.send(None) await chan.send(None)
await chan.aclose() # await chan.aclose()
except trio.BrokenResourceError: except trio.BrokenResourceError:
log.exception( log.exception(
f"Channel for {chan.uid} was already zonked..") f"Channel for {chan.uid} was already zonked..")
@ -334,7 +332,10 @@ class Actor:
return cid, q return cid, q
async def _process_messages( async def _process_messages(
self, chan: Channel, treat_as_gen: bool = False self, chan: Channel,
treat_as_gen: bool = False,
shield: bool = False,
task_status=trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
"""Process messages for the channel async-RPC style. """Process messages for the channel async-RPC style.
@ -342,33 +343,30 @@ class Actor:
""" """
# TODO: once https://github.com/python-trio/trio/issues/467 gets # TODO: once https://github.com/python-trio/trio/issues/467 gets
# worked out we'll likely want to use that! # worked out we'll likely want to use that!
msg = None
log.debug(f"Entering msg loop for {chan} from {chan.uid}") log.debug(f"Entering msg loop for {chan} from {chan.uid}")
try: try:
# internal scope allows for keeping this message
# loop running despite the current task having been
# cancelled (eg. `open_portal()` may call this method from
# a locally spawned task)
with trio.open_cancel_scope(shield=shield) as cs:
task_status.started(cs)
async for msg in chan: async for msg in chan:
if msg is None: # terminate sentinel if msg is None: # loop terminate sentinel
log.debug( log.debug(
f"Cancelling all tasks for {chan} from {chan.uid}") f"Cancelling all tasks for {chan} from {chan.uid}")
for cid, (scope, func) in self._rpc_tasks.pop( for (channel, cid) in self._rpc_tasks:
chan, {} if channel is chan:
).items(): self.cancel_task(cid, Context(channel, cid))
scope.cancel()
log.debug( log.debug(
f"Msg loop signalled to terminate for" f"Msg loop signalled to terminate for"
f" {chan} from {chan.uid}") f" {chan} from {chan.uid}")
break break
log.debug(f"Received msg {msg} from {chan.uid}") log.debug(f"Received msg {msg} from {chan.uid}")
cid = msg.get('cid') cid = msg.get('cid')
if cid: if cid:
cancel = msg.get('cancel')
if cancel:
# right now this is only implicitly used by
# async generator IPC
scope, func = self._rpc_tasks[chan][cid]
log.debug(
f"Received cancel request for task {cid}"
f" from {chan.uid}")
scope.cancel()
else:
# deliver response to local caller/waiter # deliver response to local caller/waiter
await self._push_result(chan.uid, cid, msg) await self._push_result(chan.uid, cid, msg)
log.debug( log.debug(
@ -414,14 +412,16 @@ class Actor:
# deadlock and other weird behaviour) # deadlock and other weird behaviour)
if func != self.cancel: if func != self.cancel:
if isinstance(cs, Exception): if isinstance(cs, Exception):
log.warn(f"Task for RPC func {func} failed with {cs}") log.warn(f"Task for RPC func {func} failed with"
f"{cs}")
else: else:
# mark that we have ongoing rpc tasks # mark that we have ongoing rpc tasks
self._no_more_rpc_tasks.clear() self._no_more_rpc_tasks.clear()
log.info(f"RPC func is {func}") log.info(f"RPC func is {func}")
# store cancel scope such that the rpc task can be # store cancel scope such that the rpc task can be
# cancelled gracefully if requested # cancelled gracefully if requested
self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func) self._rpc_tasks[(chan, cid)] = (
cs, func, trio.Event())
log.debug( log.debug(
f"Waiting on next msg for {chan} from {chan.uid}") f"Waiting on next msg for {chan} from {chan.uid}")
else: else:
@ -439,8 +439,14 @@ class Actor:
raise raise
# if this is the `MainProcess` we expect the error broadcasting # if this is the `MainProcess` we expect the error broadcasting
# above to trigger an error at consuming portal "checkpoints" # above to trigger an error at consuming portal "checkpoints"
except trio.Cancelled:
# debugging only
log.debug("Msg loop was cancelled")
raise
finally: finally:
log.debug(f"Exiting msg loop for {chan} from {chan.uid}") log.debug(
f"Exiting msg loop for {chan} from {chan.uid} "
f"with last msg:\n{msg}")
def _fork_main( def _fork_main(
self, self,
@ -541,8 +547,7 @@ class Actor:
if self._parent_chan: if self._parent_chan:
try: try:
# internal error so ship to parent without cid # internal error so ship to parent without cid
await self._parent_chan.send( await self._parent_chan.send(pack_error(err))
pack_error(err))
except trio.ClosedResourceError: except trio.ClosedResourceError:
log.error( log.error(
f"Failed to ship error to parent " f"Failed to ship error to parent "
@ -627,18 +632,44 @@ class Actor:
self.cancel_server() self.cancel_server()
self._root_nursery.cancel_scope.cancel() self._root_nursery.cancel_scope.cancel()
async def cancel_task(self, cid, ctx):
"""Cancel a local task.
Note this method will be treated as a streaming funciton
by remote actor-callers due to the declaration of ``ctx``
in the signature (for now).
"""
# right now this is only implicitly called by
# streaming IPC but it should be called
# to cancel any remotely spawned task
chan = ctx.chan
# the ``dict.get()`` ensures the requested task to be cancelled
# was indeed spawned by a request from this channel
scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)]
log.debug(
f"Cancelling task:\ncid: {cid}\nfunc: {func}\n"
f"peer: {chan.uid}\n")
# if func is self.cancel_task:
# return
scope.cancel()
# wait for _invoke to mark the task complete
await is_complete.wait()
log.debug(
f"Sucessfully cancelled task:\ncid: {cid}\nfunc: {func}\n"
f"peer: {chan.uid}\n")
async def cancel_rpc_tasks(self) -> None: async def cancel_rpc_tasks(self) -> None:
"""Cancel all existing RPC responder tasks using the cancel scope """Cancel all existing RPC responder tasks using the cancel scope
registered for each. registered for each.
""" """
tasks = self._rpc_tasks tasks = self._rpc_tasks
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ") log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ")
for chan, cids2scopes in tasks.items(): for (chan, cid) in tasks.copy():
log.debug(f"Cancelling all tasks for {chan.uid}") # TODO: this should really done in a nursery batch
for cid, (scope, func) in cids2scopes.items(): await self.cancel_task(cid, Context(chan, cid))
log.debug(f"Cancelling task for {func}") # if tasks:
scope.cancel()
if tasks:
log.info( log.info(
f"Waiting for remaining rpc tasks to complete {tasks}") f"Waiting for remaining rpc tasks to complete {tasks}")
await self._no_more_rpc_tasks.wait() await self._no_more_rpc_tasks.wait()
@ -810,7 +841,9 @@ async def find_actor(
sockaddr = await arb_portal.run('self', 'find_actor', name=name) sockaddr = await arb_portal.run('self', 'find_actor', name=name)
# TODO: return portals to all available actors - for now just # TODO: return portals to all available actors - for now just
# the last one that registered # the last one that registered
if sockaddr: if name == 'arbiter' and actor.is_arbiter:
raise RuntimeError("The current actor is the arbiter")
elif sockaddr:
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(*sockaddr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal