Add `Actor.cancel_task()`
Enable cancelling specific tasks from a peer actor such that when an actor task or the actor itself is cancelled, remotely spawned tasks can also be cancelled. In much the same way that you'd expect a node (task) in the `trio` task tree to cancel any subtasks, actors should be able to cancel any tasks they spawn in separate processes. To enable this:

- track rpc tasks in a flat dict keyed by `(chan, cid)`
- store an `is_complete` event to enable waiting on specific tasks to complete
- allow for shielding the msg loop inside an internal cancel scope if requested by the caller; there was an issue with `open_portal()` where the channel would be torn down because the current task was cancelled but we still need messaging to continue until the portal block is exited
- throw an error if the arbiter tries to find itself, for now
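A minimal standalone sketch of the bookkeeping this commit introduces (illustrative only: the registry layout and helper names below are simplified stand-ins for tractor's internals, and it is written against current `trio`, which spells the diff's `trio.open_cancel_scope()` as `trio.CancelScope`):

```python
import trio

# flat registry, keyed like the commit's `Actor._rpc_tasks`:
# (chan, cid) -> (cancel_scope, func, is_complete)
rpc_tasks = {}


async def invoke(chan, cid, func, task_status=trio.TASK_STATUS_IGNORED):
    """Run one RPC task inside its own cancel scope (simplified `_invoke`)."""
    try:
        with trio.CancelScope() as scope:
            rpc_tasks[(chan, cid)] = (scope, func, trio.Event())
            task_status.started(scope)
            await func()
    finally:
        # pop our entry and wake any `cancel_task()` waiter
        _, _, is_complete = rpc_tasks.pop((chan, cid))
        is_complete.set()


async def cancel_task(chan, cid):
    """Cancel one task by its flat (chan, cid) key and wait for it to finish."""
    scope, func, is_complete = rpc_tasks[(chan, cid)]
    scope.cancel()
    # wait for the task's finally-block to confirm completion
    await is_complete.wait()


async def main():
    async with trio.open_nursery() as nursery:
        await nursery.start(invoke, 'chan-a', 'cid-1', trio.sleep_forever)
        await cancel_task('chan-a', 'cid-1')  # returns once the task is gone


trio.run(main)
```

Cancelling a peer's remotely spawned task then reduces to sending a `cancel_task` request keyed by the originating channel and call id.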
parent 251ee177fa
commit 03e00886da
```diff
@@ -140,18 +140,14 @@ async def _invoke(
         task_status.started(err)
     finally:
         # RPC task bookeeping
-        tasks = actor._rpc_tasks.get(chan, None)
-        if tasks:
-            try:
-                scope, func = tasks.pop(cid)
-            except ValueError:
-                # If we're cancelled before the task returns then the
-                # cancel scope will not have been inserted yet
-                log.warn(
-                    f"Task {func} was likely cancelled before it was started")
-
-        if not tasks:
-            actor._rpc_tasks.pop(chan, None)
+        try:
+            scope, func, is_complete = actor._rpc_tasks.pop((chan, cid))
+            is_complete.set()
+        except KeyError:
+            # If we're cancelled before the task returns then the
+            # cancel scope will not have been inserted yet
+            log.warn(
+                f"Task {func} was likely cancelled before it was started")
 
         if not actor._rpc_tasks:
             log.info(f"All RPC tasks have completed")
```
```diff
@@ -197,9 +193,10 @@ class Actor:
 
         self._no_more_rpc_tasks = trio.Event()
         self._no_more_rpc_tasks.set()
+        # (chan, cid) -> (cancel_scope, func)
         self._rpc_tasks: Dict[
-            Channel,
-            Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]]
+            Tuple[Channel, str],
+            Tuple[trio._core._run.CancelScope, typing.Callable, trio.Event]
         ] = {}
         # map {uids -> {callids -> waiter queues}}
         self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
```
```diff
@@ -268,7 +265,7 @@ class Actor:
             log.warning(
                 f"already have channel(s) for {uid}:{chans}?"
             )
-        log.debug(f"Registered {chan} for {uid}")
+        log.trace(f"Registered {chan} for {uid}")
         # append new channel
         self._peers[uid].append(chan)
 
```
```diff
@@ -295,8 +292,9 @@ class Actor:
         if chan.connected():
             log.debug(f"Disconnecting channel {chan}")
             try:
+                # send our msg loop terminate sentinel
                 await chan.send(None)
-                await chan.aclose()
+                # await chan.aclose()
             except trio.BrokenResourceError:
                 log.exception(
                     f"Channel for {chan.uid} was already zonked..")
```
```diff
@@ -334,7 +332,10 @@ class Actor:
         return cid, q
 
     async def _process_messages(
-        self, chan: Channel, treat_as_gen: bool = False
+        self, chan: Channel,
+        treat_as_gen: bool = False,
+        shield: bool = False,
+        task_status=trio.TASK_STATUS_IGNORED,
     ) -> None:
         """Process messages for the channel async-RPC style.
 
```
```diff
@@ -342,91 +343,90 @@ class Actor:
         """
         # TODO: once https://github.com/python-trio/trio/issues/467 gets
         # worked out we'll likely want to use that!
+        msg = None
         log.debug(f"Entering msg loop for {chan} from {chan.uid}")
         try:
-            async for msg in chan:
-                if msg is None:  # terminate sentinel
-                    log.debug(
-                        f"Cancelling all tasks for {chan} from {chan.uid}")
-                    for cid, (scope, func) in self._rpc_tasks.pop(
-                        chan, {}
-                    ).items():
-                        scope.cancel()
-                    log.debug(
-                        f"Msg loop signalled to terminate for"
-                        f" {chan} from {chan.uid}")
-                    break
-                log.debug(f"Received msg {msg} from {chan.uid}")
-                cid = msg.get('cid')
-                if cid:
-                    cancel = msg.get('cancel')
-                    if cancel:
-                        # right now this is only implicitly used by
-                        # async generator IPC
-                        scope, func = self._rpc_tasks[chan][cid]
-                        log.debug(
-                            f"Received cancel request for task {cid}"
-                            f" from {chan.uid}")
-                        scope.cancel()
-                    else:
-                        # deliver response to local caller/waiter
-                        await self._push_result(chan.uid, cid, msg)
-                        log.debug(
-                            f"Waiting on next msg for {chan} from {chan.uid}")
-                    continue
-
-                # process command request
-                try:
-                    ns, funcname, kwargs, actorid, cid = msg['cmd']
-                except KeyError:
-                    # This is the non-rpc error case, that is, an
-                    # error **not** raised inside a call to ``_invoke()``
-                    # (i.e. no cid was provided in the msg - see above).
-                    # Push this error to all local channel consumers
-                    # (normally portals) by marking the channel as errored
-                    assert chan.uid
-                    exc = unpack_error(msg, chan=chan)
-                    chan._exc = exc
-                    raise exc
-
-                log.debug(
-                    f"Processing request from {actorid}\n"
-                    f"{ns}.{funcname}({kwargs})")
-                if ns == 'self':
-                    func = getattr(self, funcname)
-                else:
-                    # complain to client about restricted modules
-                    try:
-                        func = self._get_rpc_func(ns, funcname)
-                    except (ModuleNotExposed, AttributeError) as err:
-                        err_msg = pack_error(err)
-                        err_msg['cid'] = cid
-                        await chan.send(err_msg)
-                        continue
-
-                # spin up a task for the requested function
-                log.debug(f"Spawning task for {func}")
-                cs = await self._root_nursery.start(
-                    _invoke, self, cid, chan, func, kwargs,
-                    name=funcname
-                )
-                # never allow cancelling cancel requests (results in
-                # deadlock and other weird behaviour)
-                if func != self.cancel:
-                    if isinstance(cs, Exception):
-                        log.warn(f"Task for RPC func {func} failed with {cs}")
-                    else:
-                        # mark that we have ongoing rpc tasks
-                        self._no_more_rpc_tasks.clear()
-                        log.info(f"RPC func is {func}")
-                        # store cancel scope such that the rpc task can be
-                        # cancelled gracefully if requested
-                        self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func)
-                log.debug(
-                    f"Waiting on next msg for {chan} from {chan.uid}")
-            else:
-                # channel disconnect
-                log.debug(f"{chan} from {chan.uid} disconnected")
+            # internal scope allows for keeping this message
+            # loop running despite the current task having been
+            # cancelled (eg. `open_portal()` may call this method from
+            # a locally spawned task)
+            with trio.open_cancel_scope(shield=shield) as cs:
+                task_status.started(cs)
+                async for msg in chan:
+                    if msg is None:  # loop terminate sentinel
+                        log.debug(
+                            f"Cancelling all tasks for {chan} from {chan.uid}")
+                        for (channel, cid) in self._rpc_tasks:
+                            if channel is chan:
+                                self.cancel_task(cid, Context(channel, cid))
+                        log.debug(
+                            f"Msg loop signalled to terminate for"
+                            f" {chan} from {chan.uid}")
+                        break
+
+                    log.debug(f"Received msg {msg} from {chan.uid}")
+                    cid = msg.get('cid')
+                    if cid:
+                        # deliver response to local caller/waiter
+                        await self._push_result(chan.uid, cid, msg)
+                        log.debug(
+                            f"Waiting on next msg for {chan} from {chan.uid}")
+                        continue
+
+                    # process command request
+                    try:
+                        ns, funcname, kwargs, actorid, cid = msg['cmd']
+                    except KeyError:
+                        # This is the non-rpc error case, that is, an
+                        # error **not** raised inside a call to ``_invoke()``
+                        # (i.e. no cid was provided in the msg - see above).
+                        # Push this error to all local channel consumers
+                        # (normally portals) by marking the channel as errored
+                        assert chan.uid
+                        exc = unpack_error(msg, chan=chan)
+                        chan._exc = exc
+                        raise exc
+
+                    log.debug(
+                        f"Processing request from {actorid}\n"
+                        f"{ns}.{funcname}({kwargs})")
+                    if ns == 'self':
+                        func = getattr(self, funcname)
+                    else:
+                        # complain to client about restricted modules
+                        try:
+                            func = self._get_rpc_func(ns, funcname)
+                        except (ModuleNotExposed, AttributeError) as err:
+                            err_msg = pack_error(err)
+                            err_msg['cid'] = cid
+                            await chan.send(err_msg)
+                            continue
+
+                    # spin up a task for the requested function
+                    log.debug(f"Spawning task for {func}")
+                    cs = await self._root_nursery.start(
+                        _invoke, self, cid, chan, func, kwargs,
+                        name=funcname
+                    )
+                    # never allow cancelling cancel requests (results in
+                    # deadlock and other weird behaviour)
+                    if func != self.cancel:
+                        if isinstance(cs, Exception):
+                            log.warn(f"Task for RPC func {func} failed with"
+                                     f"{cs}")
+                        else:
+                            # mark that we have ongoing rpc tasks
+                            self._no_more_rpc_tasks.clear()
+                            log.info(f"RPC func is {func}")
+                            # store cancel scope such that the rpc task can be
+                            # cancelled gracefully if requested
+                            self._rpc_tasks[(chan, cid)] = (
+                                cs, func, trio.Event())
+                    log.debug(
+                        f"Waiting on next msg for {chan} from {chan.uid}")
+                else:
+                    # channel disconnect
+                    log.debug(f"{chan} from {chan.uid} disconnected")
 
         except trio.ClosedResourceError:
             log.error(f"{chan} form {chan.uid} broke")
```
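Aside: the shielding introduced in this hunk can be illustrated standalone. The point is that cancelling the task that runs the msg loop (as `open_portal()`'s teardown does) no longer kills messaging; the loop runs until the scope handed back through `task_status` is cancelled explicitly. A sketch only, with a hypothetical `msg_loop` helper; current `trio` spells the scope `trio.CancelScope(shield=True)` rather than the diff's `trio.open_cancel_scope()`:

```python
import trio


async def msg_loop(task_status=trio.TASK_STATUS_IGNORED):
    # shield: cancellation arriving from enclosing scopes can't interrupt
    # the loop; only cancelling `cs` itself (or a normal exit) stops it
    with trio.CancelScope(shield=True) as cs:
        task_status.started(cs)
        while True:  # stand-in for `async for msg in chan:`
            await trio.sleep(0.1)


async def main():
    async with trio.open_nursery() as nursery:
        loop_cs = await nursery.start(msg_loop)
        nursery.cancel_scope.cancel()  # alone, this leaves the loop running
        loop_cs.cancel()  # explicit cancel, as on portal exit


trio.run(main)
```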
```diff
@@ -439,8 +439,14 @@ class Actor:
             raise
             # if this is the `MainProcess` we expect the error broadcasting
             # above to trigger an error at consuming portal "checkpoints"
+        except trio.Cancelled:
+            # debugging only
+            log.debug("Msg loop was cancelled")
+            raise
         finally:
-            log.debug(f"Exiting msg loop for {chan} from {chan.uid}")
+            log.debug(
+                f"Exiting msg loop for {chan} from {chan.uid} "
+                f"with last msg:\n{msg}")
 
     def _fork_main(
         self,
```
```diff
@@ -541,8 +547,7 @@ class Actor:
         if self._parent_chan:
             try:
                 # internal error so ship to parent without cid
-                await self._parent_chan.send(
-                    pack_error(err))
+                await self._parent_chan.send(pack_error(err))
             except trio.ClosedResourceError:
                 log.error(
                     f"Failed to ship error to parent "
```
```diff
@@ -627,21 +632,47 @@ class Actor:
         self.cancel_server()
         self._root_nursery.cancel_scope.cancel()
 
+    async def cancel_task(self, cid, ctx):
+        """Cancel a local task.
+
+        Note this method will be treated as a streaming funciton
+        by remote actor-callers due to the declaration of ``ctx``
+        in the signature (for now).
+        """
+        # right now this is only implicitly called by
+        # streaming IPC but it should be called
+        # to cancel any remotely spawned task
+        chan = ctx.chan
+        # the ``dict.get()`` ensures the requested task to be cancelled
+        # was indeed spawned by a request from this channel
+        scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)]
+        log.debug(
+            f"Cancelling task:\ncid: {cid}\nfunc: {func}\n"
+            f"peer: {chan.uid}\n")
+
+        # if func is self.cancel_task:
+        #     return
+
+        scope.cancel()
+        # wait for _invoke to mark the task complete
+        await is_complete.wait()
+        log.debug(
+            f"Sucessfully cancelled task:\ncid: {cid}\nfunc: {func}\n"
+            f"peer: {chan.uid}\n")
+
     async def cancel_rpc_tasks(self) -> None:
         """Cancel all existing RPC responder tasks using the cancel scope
         registered for each.
         """
         tasks = self._rpc_tasks
-        log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}")
-        for chan, cids2scopes in tasks.items():
-            log.debug(f"Cancelling all tasks for {chan.uid}")
-            for cid, (scope, func) in cids2scopes.items():
-                log.debug(f"Cancelling task for {func}")
-                scope.cancel()
-        if tasks:
-            log.info(
-                f"Waiting for remaining rpc tasks to complete {tasks}")
-            await self._no_more_rpc_tasks.wait()
+        log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ")
+        for (chan, cid) in tasks.copy():
+            # TODO: this should really done in a nursery batch
+            await self.cancel_task(cid, Context(chan, cid))
+        # if tasks:
+        log.info(
+            f"Waiting for remaining rpc tasks to complete {tasks}")
+        await self._no_more_rpc_tasks.wait()
 
     def cancel_server(self) -> None:
         """Cancel the internal channel server nursery thereby
```
```diff
@@ -810,7 +841,9 @@ async def find_actor(
         sockaddr = await arb_portal.run('self', 'find_actor', name=name)
         # TODO: return portals to all available actors - for now just
         # the last one that registered
-        if sockaddr:
+        if name == 'arbiter' and actor.is_arbiter:
+            raise RuntimeError("The current actor is the arbiter")
+        elif sockaddr:
             async with _connect_chan(*sockaddr) as chan:
                 async with open_portal(chan) as portal:
                     yield portal
```