Add `Actor.cancel_task()`
Enable cancelling specific tasks from a peer actor such that when a actor task or the actor itself is cancelled, remotely spawned tasks can also be cancelled. In much that same way that you'd expect a node (task) in the `trio` task tree to cancel any subtasks, actors should be able to cancel any tasks they spawn in separate processes. To enable this: - track rpc tasks in a flat dict keyed by (chan, cid) - store a `is_complete` event to enable waiting on specific tasks to complete - allow for shielding the msg loop inside an internal cancel scope if requested by the caller; there was an issue with `open_portal()` where the channel would be torn down because the current task was cancelled but we still need messaging to continue until the portal block is exited - throw an error if the arbiter tries to find itself for nowcontexts
parent
251ee177fa
commit
03e00886da
|
@ -140,19 +140,15 @@ async def _invoke(
|
|||
task_status.started(err)
|
||||
finally:
|
||||
# RPC task bookeeping
|
||||
tasks = actor._rpc_tasks.get(chan, None)
|
||||
if tasks:
|
||||
try:
|
||||
scope, func = tasks.pop(cid)
|
||||
except ValueError:
|
||||
scope, func, is_complete = actor._rpc_tasks.pop((chan, cid))
|
||||
is_complete.set()
|
||||
except KeyError:
|
||||
# If we're cancelled before the task returns then the
|
||||
# cancel scope will not have been inserted yet
|
||||
log.warn(
|
||||
f"Task {func} was likely cancelled before it was started")
|
||||
|
||||
if not tasks:
|
||||
actor._rpc_tasks.pop(chan, None)
|
||||
|
||||
if not actor._rpc_tasks:
|
||||
log.info(f"All RPC tasks have completed")
|
||||
actor._no_more_rpc_tasks.set()
|
||||
|
@ -197,9 +193,10 @@ class Actor:
|
|||
|
||||
self._no_more_rpc_tasks = trio.Event()
|
||||
self._no_more_rpc_tasks.set()
|
||||
# (chan, cid) -> (cancel_scope, func)
|
||||
self._rpc_tasks: Dict[
|
||||
Channel,
|
||||
Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]]
|
||||
Tuple[Channel, str],
|
||||
Tuple[trio._core._run.CancelScope, typing.Callable, trio.Event]
|
||||
] = {}
|
||||
# map {uids -> {callids -> waiter queues}}
|
||||
self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
|
||||
|
@ -268,7 +265,7 @@ class Actor:
|
|||
log.warning(
|
||||
f"already have channel(s) for {uid}:{chans}?"
|
||||
)
|
||||
log.debug(f"Registered {chan} for {uid}")
|
||||
log.trace(f"Registered {chan} for {uid}")
|
||||
# append new channel
|
||||
self._peers[uid].append(chan)
|
||||
|
||||
|
@ -295,8 +292,9 @@ class Actor:
|
|||
if chan.connected():
|
||||
log.debug(f"Disconnecting channel {chan}")
|
||||
try:
|
||||
# send our msg loop terminate sentinel
|
||||
await chan.send(None)
|
||||
await chan.aclose()
|
||||
# await chan.aclose()
|
||||
except trio.BrokenResourceError:
|
||||
log.exception(
|
||||
f"Channel for {chan.uid} was already zonked..")
|
||||
|
@ -334,7 +332,10 @@ class Actor:
|
|||
return cid, q
|
||||
|
||||
async def _process_messages(
|
||||
self, chan: Channel, treat_as_gen: bool = False
|
||||
self, chan: Channel,
|
||||
treat_as_gen: bool = False,
|
||||
shield: bool = False,
|
||||
task_status=trio.TASK_STATUS_IGNORED,
|
||||
) -> None:
|
||||
"""Process messages for the channel async-RPC style.
|
||||
|
||||
|
@ -342,33 +343,30 @@ class Actor:
|
|||
"""
|
||||
# TODO: once https://github.com/python-trio/trio/issues/467 gets
|
||||
# worked out we'll likely want to use that!
|
||||
msg = None
|
||||
log.debug(f"Entering msg loop for {chan} from {chan.uid}")
|
||||
try:
|
||||
# internal scope allows for keeping this message
|
||||
# loop running despite the current task having been
|
||||
# cancelled (eg. `open_portal()` may call this method from
|
||||
# a locally spawned task)
|
||||
with trio.open_cancel_scope(shield=shield) as cs:
|
||||
task_status.started(cs)
|
||||
async for msg in chan:
|
||||
if msg is None: # terminate sentinel
|
||||
if msg is None: # loop terminate sentinel
|
||||
log.debug(
|
||||
f"Cancelling all tasks for {chan} from {chan.uid}")
|
||||
for cid, (scope, func) in self._rpc_tasks.pop(
|
||||
chan, {}
|
||||
).items():
|
||||
scope.cancel()
|
||||
for (channel, cid) in self._rpc_tasks:
|
||||
if channel is chan:
|
||||
self.cancel_task(cid, Context(channel, cid))
|
||||
log.debug(
|
||||
f"Msg loop signalled to terminate for"
|
||||
f" {chan} from {chan.uid}")
|
||||
break
|
||||
|
||||
log.debug(f"Received msg {msg} from {chan.uid}")
|
||||
cid = msg.get('cid')
|
||||
if cid:
|
||||
cancel = msg.get('cancel')
|
||||
if cancel:
|
||||
# right now this is only implicitly used by
|
||||
# async generator IPC
|
||||
scope, func = self._rpc_tasks[chan][cid]
|
||||
log.debug(
|
||||
f"Received cancel request for task {cid}"
|
||||
f" from {chan.uid}")
|
||||
scope.cancel()
|
||||
else:
|
||||
# deliver response to local caller/waiter
|
||||
await self._push_result(chan.uid, cid, msg)
|
||||
log.debug(
|
||||
|
@ -414,14 +412,16 @@ class Actor:
|
|||
# deadlock and other weird behaviour)
|
||||
if func != self.cancel:
|
||||
if isinstance(cs, Exception):
|
||||
log.warn(f"Task for RPC func {func} failed with {cs}")
|
||||
log.warn(f"Task for RPC func {func} failed with"
|
||||
f"{cs}")
|
||||
else:
|
||||
# mark that we have ongoing rpc tasks
|
||||
self._no_more_rpc_tasks.clear()
|
||||
log.info(f"RPC func is {func}")
|
||||
# store cancel scope such that the rpc task can be
|
||||
# cancelled gracefully if requested
|
||||
self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func)
|
||||
self._rpc_tasks[(chan, cid)] = (
|
||||
cs, func, trio.Event())
|
||||
log.debug(
|
||||
f"Waiting on next msg for {chan} from {chan.uid}")
|
||||
else:
|
||||
|
@ -439,8 +439,14 @@ class Actor:
|
|||
raise
|
||||
# if this is the `MainProcess` we expect the error broadcasting
|
||||
# above to trigger an error at consuming portal "checkpoints"
|
||||
except trio.Cancelled:
|
||||
# debugging only
|
||||
log.debug("Msg loop was cancelled")
|
||||
raise
|
||||
finally:
|
||||
log.debug(f"Exiting msg loop for {chan} from {chan.uid}")
|
||||
log.debug(
|
||||
f"Exiting msg loop for {chan} from {chan.uid} "
|
||||
f"with last msg:\n{msg}")
|
||||
|
||||
def _fork_main(
|
||||
self,
|
||||
|
@ -541,8 +547,7 @@ class Actor:
|
|||
if self._parent_chan:
|
||||
try:
|
||||
# internal error so ship to parent without cid
|
||||
await self._parent_chan.send(
|
||||
pack_error(err))
|
||||
await self._parent_chan.send(pack_error(err))
|
||||
except trio.ClosedResourceError:
|
||||
log.error(
|
||||
f"Failed to ship error to parent "
|
||||
|
@ -627,18 +632,44 @@ class Actor:
|
|||
self.cancel_server()
|
||||
self._root_nursery.cancel_scope.cancel()
|
||||
|
||||
async def cancel_task(self, cid, ctx):
|
||||
"""Cancel a local task.
|
||||
|
||||
Note this method will be treated as a streaming funciton
|
||||
by remote actor-callers due to the declaration of ``ctx``
|
||||
in the signature (for now).
|
||||
"""
|
||||
# right now this is only implicitly called by
|
||||
# streaming IPC but it should be called
|
||||
# to cancel any remotely spawned task
|
||||
chan = ctx.chan
|
||||
# the ``dict.get()`` ensures the requested task to be cancelled
|
||||
# was indeed spawned by a request from this channel
|
||||
scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)]
|
||||
log.debug(
|
||||
f"Cancelling task:\ncid: {cid}\nfunc: {func}\n"
|
||||
f"peer: {chan.uid}\n")
|
||||
|
||||
# if func is self.cancel_task:
|
||||
# return
|
||||
|
||||
scope.cancel()
|
||||
# wait for _invoke to mark the task complete
|
||||
await is_complete.wait()
|
||||
log.debug(
|
||||
f"Sucessfully cancelled task:\ncid: {cid}\nfunc: {func}\n"
|
||||
f"peer: {chan.uid}\n")
|
||||
|
||||
async def cancel_rpc_tasks(self) -> None:
|
||||
"""Cancel all existing RPC responder tasks using the cancel scope
|
||||
registered for each.
|
||||
"""
|
||||
tasks = self._rpc_tasks
|
||||
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ")
|
||||
for chan, cids2scopes in tasks.items():
|
||||
log.debug(f"Cancelling all tasks for {chan.uid}")
|
||||
for cid, (scope, func) in cids2scopes.items():
|
||||
log.debug(f"Cancelling task for {func}")
|
||||
scope.cancel()
|
||||
if tasks:
|
||||
for (chan, cid) in tasks.copy():
|
||||
# TODO: this should really done in a nursery batch
|
||||
await self.cancel_task(cid, Context(chan, cid))
|
||||
# if tasks:
|
||||
log.info(
|
||||
f"Waiting for remaining rpc tasks to complete {tasks}")
|
||||
await self._no_more_rpc_tasks.wait()
|
||||
|
@ -810,7 +841,9 @@ async def find_actor(
|
|||
sockaddr = await arb_portal.run('self', 'find_actor', name=name)
|
||||
# TODO: return portals to all available actors - for now just
|
||||
# the last one that registered
|
||||
if sockaddr:
|
||||
if name == 'arbiter' and actor.is_arbiter:
|
||||
raise RuntimeError("The current actor is the arbiter")
|
||||
elif sockaddr:
|
||||
async with _connect_chan(*sockaddr) as chan:
|
||||
async with open_portal(chan) as portal:
|
||||
yield portal
|
||||
|
|
Loading…
Reference in New Issue