Add `Actor.cancel_task()`

Enable cancelling specific tasks from a peer actor such that when
a actor task or the actor itself is cancelled, remotely spawned tasks
can also be cancelled. In much that same way that you'd expect a node
(task) in the `trio` task tree to cancel any subtasks, actors should
be able to cancel any tasks they spawn in separate processes.

To enable this:
- track rpc tasks in a flat dict keyed by (chan, cid)
- store a `is_complete` event to enable waiting on specific
  tasks to complete
- allow for shielding the msg loop inside an internal cancel scope
  if requested by the caller; there was an issue with `open_portal()`
  where the channel would be torn down because the current task was
  cancelled but we still need messaging to continue until the portal
  block is exited
- throw an error if the arbiter tries to find itself for now
contexts
Tyler Goodlet 2019-01-21 00:16:20 -05:00
parent 251ee177fa
commit 03e00886da
1 changed files with 139 additions and 106 deletions

View File

@ -140,18 +140,14 @@ async def _invoke(
task_status.started(err)
finally:
# RPC task bookeeping
tasks = actor._rpc_tasks.get(chan, None)
if tasks:
try:
scope, func = tasks.pop(cid)
except ValueError:
# If we're cancelled before the task returns then the
# cancel scope will not have been inserted yet
log.warn(
f"Task {func} was likely cancelled before it was started")
if not tasks:
actor._rpc_tasks.pop(chan, None)
try:
scope, func, is_complete = actor._rpc_tasks.pop((chan, cid))
is_complete.set()
except KeyError:
# If we're cancelled before the task returns then the
# cancel scope will not have been inserted yet
log.warn(
f"Task {func} was likely cancelled before it was started")
if not actor._rpc_tasks:
log.info(f"All RPC tasks have completed")
@ -197,9 +193,10 @@ class Actor:
self._no_more_rpc_tasks = trio.Event()
self._no_more_rpc_tasks.set()
# (chan, cid) -> (cancel_scope, func)
self._rpc_tasks: Dict[
Channel,
Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]]
Tuple[Channel, str],
Tuple[trio._core._run.CancelScope, typing.Callable, trio.Event]
] = {}
# map {uids -> {callids -> waiter queues}}
self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
@ -268,7 +265,7 @@ class Actor:
log.warning(
f"already have channel(s) for {uid}:{chans}?"
)
log.debug(f"Registered {chan} for {uid}")
log.trace(f"Registered {chan} for {uid}")
# append new channel
self._peers[uid].append(chan)
@ -295,8 +292,9 @@ class Actor:
if chan.connected():
log.debug(f"Disconnecting channel {chan}")
try:
# send our msg loop terminate sentinel
await chan.send(None)
await chan.aclose()
# await chan.aclose()
except trio.BrokenResourceError:
log.exception(
f"Channel for {chan.uid} was already zonked..")
@ -334,7 +332,10 @@ class Actor:
return cid, q
async def _process_messages(
self, chan: Channel, treat_as_gen: bool = False
self, chan: Channel,
treat_as_gen: bool = False,
shield: bool = False,
task_status=trio.TASK_STATUS_IGNORED,
) -> None:
"""Process messages for the channel async-RPC style.
@ -342,91 +343,90 @@ class Actor:
"""
# TODO: once https://github.com/python-trio/trio/issues/467 gets
# worked out we'll likely want to use that!
msg = None
log.debug(f"Entering msg loop for {chan} from {chan.uid}")
try:
async for msg in chan:
if msg is None: # terminate sentinel
log.debug(
f"Cancelling all tasks for {chan} from {chan.uid}")
for cid, (scope, func) in self._rpc_tasks.pop(
chan, {}
).items():
scope.cancel()
log.debug(
f"Msg loop signalled to terminate for"
f" {chan} from {chan.uid}")
break
log.debug(f"Received msg {msg} from {chan.uid}")
cid = msg.get('cid')
if cid:
cancel = msg.get('cancel')
if cancel:
# right now this is only implicitly used by
# async generator IPC
scope, func = self._rpc_tasks[chan][cid]
# internal scope allows for keeping this message
# loop running despite the current task having been
# cancelled (eg. `open_portal()` may call this method from
# a locally spawned task)
with trio.open_cancel_scope(shield=shield) as cs:
task_status.started(cs)
async for msg in chan:
if msg is None: # loop terminate sentinel
log.debug(
f"Received cancel request for task {cid}"
f" from {chan.uid}")
scope.cancel()
else:
f"Cancelling all tasks for {chan} from {chan.uid}")
for (channel, cid) in self._rpc_tasks:
if channel is chan:
self.cancel_task(cid, Context(channel, cid))
log.debug(
f"Msg loop signalled to terminate for"
f" {chan} from {chan.uid}")
break
log.debug(f"Received msg {msg} from {chan.uid}")
cid = msg.get('cid')
if cid:
# deliver response to local caller/waiter
await self._push_result(chan.uid, cid, msg)
log.debug(
f"Waiting on next msg for {chan} from {chan.uid}")
continue
# process command request
try:
ns, funcname, kwargs, actorid, cid = msg['cmd']
except KeyError:
# This is the non-rpc error case, that is, an
# error **not** raised inside a call to ``_invoke()``
# (i.e. no cid was provided in the msg - see above).
# Push this error to all local channel consumers
# (normally portals) by marking the channel as errored
assert chan.uid
exc = unpack_error(msg, chan=chan)
chan._exc = exc
raise exc
log.debug(
f"Processing request from {actorid}\n"
f"{ns}.{funcname}({kwargs})")
if ns == 'self':
func = getattr(self, funcname)
else:
# complain to client about restricted modules
try:
func = self._get_rpc_func(ns, funcname)
except (ModuleNotExposed, AttributeError) as err:
err_msg = pack_error(err)
err_msg['cid'] = cid
await chan.send(err_msg)
continue
# spin up a task for the requested function
log.debug(f"Spawning task for {func}")
cs = await self._root_nursery.start(
_invoke, self, cid, chan, func, kwargs,
name=funcname
)
# never allow cancelling cancel requests (results in
# deadlock and other weird behaviour)
if func != self.cancel:
if isinstance(cs, Exception):
log.warn(f"Task for RPC func {func} failed with {cs}")
# process command request
try:
ns, funcname, kwargs, actorid, cid = msg['cmd']
except KeyError:
# This is the non-rpc error case, that is, an
# error **not** raised inside a call to ``_invoke()``
# (i.e. no cid was provided in the msg - see above).
# Push this error to all local channel consumers
# (normally portals) by marking the channel as errored
assert chan.uid
exc = unpack_error(msg, chan=chan)
chan._exc = exc
raise exc
log.debug(
f"Processing request from {actorid}\n"
f"{ns}.{funcname}({kwargs})")
if ns == 'self':
func = getattr(self, funcname)
else:
# mark that we have ongoing rpc tasks
self._no_more_rpc_tasks.clear()
log.info(f"RPC func is {func}")
# store cancel scope such that the rpc task can be
# cancelled gracefully if requested
self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func)
log.debug(
f"Waiting on next msg for {chan} from {chan.uid}")
else:
# channel disconnect
log.debug(f"{chan} from {chan.uid} disconnected")
# complain to client about restricted modules
try:
func = self._get_rpc_func(ns, funcname)
except (ModuleNotExposed, AttributeError) as err:
err_msg = pack_error(err)
err_msg['cid'] = cid
await chan.send(err_msg)
continue
# spin up a task for the requested function
log.debug(f"Spawning task for {func}")
cs = await self._root_nursery.start(
_invoke, self, cid, chan, func, kwargs,
name=funcname
)
# never allow cancelling cancel requests (results in
# deadlock and other weird behaviour)
if func != self.cancel:
if isinstance(cs, Exception):
log.warn(f"Task for RPC func {func} failed with"
f"{cs}")
else:
# mark that we have ongoing rpc tasks
self._no_more_rpc_tasks.clear()
log.info(f"RPC func is {func}")
# store cancel scope such that the rpc task can be
# cancelled gracefully if requested
self._rpc_tasks[(chan, cid)] = (
cs, func, trio.Event())
log.debug(
f"Waiting on next msg for {chan} from {chan.uid}")
else:
# channel disconnect
log.debug(f"{chan} from {chan.uid} disconnected")
except trio.ClosedResourceError:
log.error(f"{chan} form {chan.uid} broke")
@ -439,8 +439,14 @@ class Actor:
raise
# if this is the `MainProcess` we expect the error broadcasting
# above to trigger an error at consuming portal "checkpoints"
except trio.Cancelled:
# debugging only
log.debug("Msg loop was cancelled")
raise
finally:
log.debug(f"Exiting msg loop for {chan} from {chan.uid}")
log.debug(
f"Exiting msg loop for {chan} from {chan.uid} "
f"with last msg:\n{msg}")
def _fork_main(
self,
@ -541,8 +547,7 @@ class Actor:
if self._parent_chan:
try:
# internal error so ship to parent without cid
await self._parent_chan.send(
pack_error(err))
await self._parent_chan.send(pack_error(err))
except trio.ClosedResourceError:
log.error(
f"Failed to ship error to parent "
@ -627,21 +632,47 @@ class Actor:
self.cancel_server()
self._root_nursery.cancel_scope.cancel()
async def cancel_task(self, cid, ctx):
"""Cancel a local task.
Note this method will be treated as a streaming funciton
by remote actor-callers due to the declaration of ``ctx``
in the signature (for now).
"""
# right now this is only implicitly called by
# streaming IPC but it should be called
# to cancel any remotely spawned task
chan = ctx.chan
# the ``dict.get()`` ensures the requested task to be cancelled
# was indeed spawned by a request from this channel
scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)]
log.debug(
f"Cancelling task:\ncid: {cid}\nfunc: {func}\n"
f"peer: {chan.uid}\n")
# if func is self.cancel_task:
# return
scope.cancel()
# wait for _invoke to mark the task complete
await is_complete.wait()
log.debug(
f"Sucessfully cancelled task:\ncid: {cid}\nfunc: {func}\n"
f"peer: {chan.uid}\n")
async def cancel_rpc_tasks(self) -> None:
"""Cancel all existing RPC responder tasks using the cancel scope
registered for each.
"""
tasks = self._rpc_tasks
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}")
for chan, cids2scopes in tasks.items():
log.debug(f"Cancelling all tasks for {chan.uid}")
for cid, (scope, func) in cids2scopes.items():
log.debug(f"Cancelling task for {func}")
scope.cancel()
if tasks:
log.info(
f"Waiting for remaining rpc tasks to complete {tasks}")
await self._no_more_rpc_tasks.wait()
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ")
for (chan, cid) in tasks.copy():
# TODO: this should really done in a nursery batch
await self.cancel_task(cid, Context(chan, cid))
# if tasks:
log.info(
f"Waiting for remaining rpc tasks to complete {tasks}")
await self._no_more_rpc_tasks.wait()
def cancel_server(self) -> None:
"""Cancel the internal channel server nursery thereby
@ -810,7 +841,9 @@ async def find_actor(
sockaddr = await arb_portal.run('self', 'find_actor', name=name)
# TODO: return portals to all available actors - for now just
# the last one that registered
if sockaddr:
if name == 'arbiter' and actor.is_arbiter:
raise RuntimeError("The current actor is the arbiter")
elif sockaddr:
async with _connect_chan(*sockaddr) as chan:
async with open_portal(chan) as portal:
yield portal