forked from goodboy/tractor
				
			Add `Actor.cancel_task()`
Enable cancelling specific tasks from a peer actor such that when a actor task or the actor itself is cancelled, remotely spawned tasks can also be cancelled. In much that same way that you'd expect a node (task) in the `trio` task tree to cancel any subtasks, actors should be able to cancel any tasks they spawn in separate processes. To enable this: - track rpc tasks in a flat dict keyed by (chan, cid) - store a `is_complete` event to enable waiting on specific tasks to complete - allow for shielding the msg loop inside an internal cancel scope if requested by the caller; there was an issue with `open_portal()` where the channel would be torn down because the current task was cancelled but we still need messaging to continue until the portal block is exited - throw an error if the arbiter tries to find itself for nowcontexts
							parent
							
								
									251ee177fa
								
							
						
					
					
						commit
						03e00886da
					
				|  | @ -140,19 +140,15 @@ async def _invoke( | ||||||
|             task_status.started(err) |             task_status.started(err) | ||||||
|     finally: |     finally: | ||||||
|         # RPC task bookeeping |         # RPC task bookeeping | ||||||
|         tasks = actor._rpc_tasks.get(chan, None) |  | ||||||
|         if tasks: |  | ||||||
|         try: |         try: | ||||||
|                 scope, func = tasks.pop(cid) |             scope, func, is_complete = actor._rpc_tasks.pop((chan, cid)) | ||||||
|             except ValueError: |             is_complete.set() | ||||||
|  |         except KeyError: | ||||||
|             # If we're cancelled before the task returns then the |             # If we're cancelled before the task returns then the | ||||||
|             # cancel scope will not have been inserted yet |             # cancel scope will not have been inserted yet | ||||||
|             log.warn( |             log.warn( | ||||||
|                 f"Task {func} was likely cancelled before it was started") |                 f"Task {func} was likely cancelled before it was started") | ||||||
| 
 | 
 | ||||||
|         if not tasks: |  | ||||||
|             actor._rpc_tasks.pop(chan, None) |  | ||||||
| 
 |  | ||||||
|         if not actor._rpc_tasks: |         if not actor._rpc_tasks: | ||||||
|             log.info(f"All RPC tasks have completed") |             log.info(f"All RPC tasks have completed") | ||||||
|             actor._no_more_rpc_tasks.set() |             actor._no_more_rpc_tasks.set() | ||||||
|  | @ -197,9 +193,10 @@ class Actor: | ||||||
| 
 | 
 | ||||||
|         self._no_more_rpc_tasks = trio.Event() |         self._no_more_rpc_tasks = trio.Event() | ||||||
|         self._no_more_rpc_tasks.set() |         self._no_more_rpc_tasks.set() | ||||||
|  |         # (chan, cid) -> (cancel_scope, func) | ||||||
|         self._rpc_tasks: Dict[ |         self._rpc_tasks: Dict[ | ||||||
|             Channel, |             Tuple[Channel, str], | ||||||
|             Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]] |             Tuple[trio._core._run.CancelScope, typing.Callable, trio.Event] | ||||||
|         ] = {} |         ] = {} | ||||||
|         # map {uids -> {callids -> waiter queues}} |         # map {uids -> {callids -> waiter queues}} | ||||||
|         self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {} |         self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {} | ||||||
|  | @ -268,7 +265,7 @@ class Actor: | ||||||
|             log.warning( |             log.warning( | ||||||
|                 f"already have channel(s) for {uid}:{chans}?" |                 f"already have channel(s) for {uid}:{chans}?" | ||||||
|             ) |             ) | ||||||
|         log.debug(f"Registered {chan} for {uid}") |         log.trace(f"Registered {chan} for {uid}") | ||||||
|         # append new channel |         # append new channel | ||||||
|         self._peers[uid].append(chan) |         self._peers[uid].append(chan) | ||||||
| 
 | 
 | ||||||
|  | @ -295,8 +292,9 @@ class Actor: | ||||||
|             if chan.connected(): |             if chan.connected(): | ||||||
|                 log.debug(f"Disconnecting channel {chan}") |                 log.debug(f"Disconnecting channel {chan}") | ||||||
|                 try: |                 try: | ||||||
|  |                     # send our msg loop terminate sentinel | ||||||
|                     await chan.send(None) |                     await chan.send(None) | ||||||
|                     await chan.aclose() |                     # await chan.aclose() | ||||||
|                 except trio.BrokenResourceError: |                 except trio.BrokenResourceError: | ||||||
|                     log.exception( |                     log.exception( | ||||||
|                         f"Channel for {chan.uid} was already zonked..") |                         f"Channel for {chan.uid} was already zonked..") | ||||||
|  | @ -334,7 +332,10 @@ class Actor: | ||||||
|         return cid, q |         return cid, q | ||||||
| 
 | 
 | ||||||
|     async def _process_messages( |     async def _process_messages( | ||||||
|         self, chan: Channel, treat_as_gen: bool = False |         self, chan: Channel, | ||||||
|  |         treat_as_gen: bool = False, | ||||||
|  |         shield: bool = False, | ||||||
|  |         task_status=trio.TASK_STATUS_IGNORED, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         """Process messages for the channel async-RPC style. |         """Process messages for the channel async-RPC style. | ||||||
| 
 | 
 | ||||||
|  | @ -342,33 +343,30 @@ class Actor: | ||||||
|         """ |         """ | ||||||
|         # TODO: once https://github.com/python-trio/trio/issues/467 gets |         # TODO: once https://github.com/python-trio/trio/issues/467 gets | ||||||
|         # worked out we'll likely want to use that! |         # worked out we'll likely want to use that! | ||||||
|  |         msg = None | ||||||
|         log.debug(f"Entering msg loop for {chan} from {chan.uid}") |         log.debug(f"Entering msg loop for {chan} from {chan.uid}") | ||||||
|         try: |         try: | ||||||
|  |             # internal scope allows for keeping this message | ||||||
|  |             # loop running despite the current task having been | ||||||
|  |             # cancelled (eg. `open_portal()` may call this method from | ||||||
|  |             # a locally spawned task) | ||||||
|  |             with trio.open_cancel_scope(shield=shield) as cs: | ||||||
|  |                 task_status.started(cs) | ||||||
|                 async for msg in chan: |                 async for msg in chan: | ||||||
|                 if msg is None:  # terminate sentinel |                     if msg is None:  # loop terminate sentinel | ||||||
|                         log.debug( |                         log.debug( | ||||||
|                             f"Cancelling all tasks for {chan} from {chan.uid}") |                             f"Cancelling all tasks for {chan} from {chan.uid}") | ||||||
|                     for cid, (scope, func) in self._rpc_tasks.pop( |                         for (channel, cid) in self._rpc_tasks: | ||||||
|                         chan, {} |                             if channel is chan: | ||||||
|                     ).items(): |                                 self.cancel_task(cid, Context(channel, cid)) | ||||||
|                         scope.cancel() |  | ||||||
|                         log.debug( |                         log.debug( | ||||||
|                                 f"Msg loop signalled to terminate for" |                                 f"Msg loop signalled to terminate for" | ||||||
|                                 f" {chan} from {chan.uid}") |                                 f" {chan} from {chan.uid}") | ||||||
|                         break |                         break | ||||||
|  | 
 | ||||||
|                     log.debug(f"Received msg {msg} from {chan.uid}") |                     log.debug(f"Received msg {msg} from {chan.uid}") | ||||||
|                     cid = msg.get('cid') |                     cid = msg.get('cid') | ||||||
|                     if cid: |                     if cid: | ||||||
|                     cancel = msg.get('cancel') |  | ||||||
|                     if cancel: |  | ||||||
|                         # right now this is only implicitly used by |  | ||||||
|                         # async generator IPC |  | ||||||
|                         scope, func = self._rpc_tasks[chan][cid] |  | ||||||
|                         log.debug( |  | ||||||
|                             f"Received cancel request for task {cid}" |  | ||||||
|                             f" from {chan.uid}") |  | ||||||
|                         scope.cancel() |  | ||||||
|                     else: |  | ||||||
|                         # deliver response to local caller/waiter |                         # deliver response to local caller/waiter | ||||||
|                         await self._push_result(chan.uid, cid, msg) |                         await self._push_result(chan.uid, cid, msg) | ||||||
|                         log.debug( |                         log.debug( | ||||||
|  | @ -414,14 +412,16 @@ class Actor: | ||||||
|                     # deadlock and other weird behaviour) |                     # deadlock and other weird behaviour) | ||||||
|                     if func != self.cancel: |                     if func != self.cancel: | ||||||
|                         if isinstance(cs, Exception): |                         if isinstance(cs, Exception): | ||||||
|                         log.warn(f"Task for RPC func {func} failed with {cs}") |                             log.warn(f"Task for RPC func {func} failed with" | ||||||
|  |                                      f"{cs}") | ||||||
|                         else: |                         else: | ||||||
|                             # mark that we have ongoing rpc tasks |                             # mark that we have ongoing rpc tasks | ||||||
|                             self._no_more_rpc_tasks.clear() |                             self._no_more_rpc_tasks.clear() | ||||||
|                             log.info(f"RPC func is {func}") |                             log.info(f"RPC func is {func}") | ||||||
|                             # store cancel scope such that the rpc task can be |                             # store cancel scope such that the rpc task can be | ||||||
|                             # cancelled gracefully if requested |                             # cancelled gracefully if requested | ||||||
|                         self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func) |                             self._rpc_tasks[(chan, cid)] = ( | ||||||
|  |                                 cs, func, trio.Event()) | ||||||
|                     log.debug( |                     log.debug( | ||||||
|                         f"Waiting on next msg for {chan} from {chan.uid}") |                         f"Waiting on next msg for {chan} from {chan.uid}") | ||||||
|                 else: |                 else: | ||||||
|  | @ -439,8 +439,14 @@ class Actor: | ||||||
|             raise |             raise | ||||||
|             # if this is the `MainProcess` we expect the error broadcasting |             # if this is the `MainProcess` we expect the error broadcasting | ||||||
|             # above to trigger an error at consuming portal "checkpoints" |             # above to trigger an error at consuming portal "checkpoints" | ||||||
|  |         except trio.Cancelled: | ||||||
|  |             # debugging only | ||||||
|  |             log.debug("Msg loop was cancelled") | ||||||
|  |             raise | ||||||
|         finally: |         finally: | ||||||
|             log.debug(f"Exiting msg loop for {chan} from {chan.uid}") |             log.debug( | ||||||
|  |                 f"Exiting msg loop for {chan} from {chan.uid} " | ||||||
|  |                 f"with last msg:\n{msg}") | ||||||
| 
 | 
 | ||||||
|     def _fork_main( |     def _fork_main( | ||||||
|         self, |         self, | ||||||
|  | @ -541,8 +547,7 @@ class Actor: | ||||||
|             if self._parent_chan: |             if self._parent_chan: | ||||||
|                 try: |                 try: | ||||||
|                     # internal error so ship to parent without cid |                     # internal error so ship to parent without cid | ||||||
|                     await self._parent_chan.send( |                     await self._parent_chan.send(pack_error(err)) | ||||||
|                         pack_error(err)) |  | ||||||
|                 except trio.ClosedResourceError: |                 except trio.ClosedResourceError: | ||||||
|                     log.error( |                     log.error( | ||||||
|                         f"Failed to ship error to parent " |                         f"Failed to ship error to parent " | ||||||
|  | @ -627,18 +632,44 @@ class Actor: | ||||||
|         self.cancel_server() |         self.cancel_server() | ||||||
|         self._root_nursery.cancel_scope.cancel() |         self._root_nursery.cancel_scope.cancel() | ||||||
| 
 | 
 | ||||||
|  |     async def cancel_task(self, cid, ctx): | ||||||
|  |         """Cancel a local task. | ||||||
|  | 
 | ||||||
|  |         Note this method will be treated as a streaming funciton | ||||||
|  |         by remote actor-callers due to the declaration of ``ctx`` | ||||||
|  |         in the signature (for now). | ||||||
|  |         """ | ||||||
|  |         # right now this is only implicitly called by | ||||||
|  |         # streaming IPC but it should be called | ||||||
|  |         # to cancel any remotely spawned task | ||||||
|  |         chan = ctx.chan | ||||||
|  |         # the ``dict.get()`` ensures the requested task to be cancelled | ||||||
|  |         # was indeed spawned by a request from this channel | ||||||
|  |         scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)] | ||||||
|  |         log.debug( | ||||||
|  |             f"Cancelling task:\ncid: {cid}\nfunc: {func}\n" | ||||||
|  |             f"peer: {chan.uid}\n") | ||||||
|  | 
 | ||||||
|  |         # if func is self.cancel_task: | ||||||
|  |         #     return | ||||||
|  | 
 | ||||||
|  |         scope.cancel() | ||||||
|  |         # wait for _invoke to mark the task complete | ||||||
|  |         await is_complete.wait() | ||||||
|  |         log.debug( | ||||||
|  |             f"Sucessfully cancelled task:\ncid: {cid}\nfunc: {func}\n" | ||||||
|  |             f"peer: {chan.uid}\n") | ||||||
|  | 
 | ||||||
|     async def cancel_rpc_tasks(self) -> None: |     async def cancel_rpc_tasks(self) -> None: | ||||||
|         """Cancel all existing RPC responder tasks using the cancel scope |         """Cancel all existing RPC responder tasks using the cancel scope | ||||||
|         registered for each. |         registered for each. | ||||||
|         """ |         """ | ||||||
|         tasks = self._rpc_tasks |         tasks = self._rpc_tasks | ||||||
|         log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}") |         log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ") | ||||||
|         for chan, cids2scopes in tasks.items(): |         for (chan, cid) in tasks.copy(): | ||||||
|             log.debug(f"Cancelling all tasks for {chan.uid}") |             # TODO: this should really done in a nursery batch | ||||||
|             for cid, (scope, func) in cids2scopes.items(): |             await self.cancel_task(cid, Context(chan, cid)) | ||||||
|                 log.debug(f"Cancelling task for {func}") |         # if tasks: | ||||||
|                 scope.cancel() |  | ||||||
|         if tasks: |  | ||||||
|         log.info( |         log.info( | ||||||
|             f"Waiting for remaining rpc tasks to complete {tasks}") |             f"Waiting for remaining rpc tasks to complete {tasks}") | ||||||
|         await self._no_more_rpc_tasks.wait() |         await self._no_more_rpc_tasks.wait() | ||||||
|  | @ -810,7 +841,9 @@ async def find_actor( | ||||||
|         sockaddr = await arb_portal.run('self', 'find_actor', name=name) |         sockaddr = await arb_portal.run('self', 'find_actor', name=name) | ||||||
|         # TODO: return portals to all available actors - for now just |         # TODO: return portals to all available actors - for now just | ||||||
|         # the last one that registered |         # the last one that registered | ||||||
|         if sockaddr: |         if name == 'arbiter' and actor.is_arbiter: | ||||||
|  |             raise RuntimeError("The current actor is the arbiter") | ||||||
|  |         elif sockaddr: | ||||||
|             async with _connect_chan(*sockaddr) as chan: |             async with _connect_chan(*sockaddr) as chan: | ||||||
|                 async with open_portal(chan) as portal: |                 async with open_portal(chan) as portal: | ||||||
|                     yield portal |                     yield portal | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue