Add support for cancelling remote tasks via a msg
parent
c0cdb3945a
commit
4dccb44c67
|
@ -103,7 +103,7 @@ async def _invoke(
|
|||
# TODO: we should really support a proper
|
||||
# `StopAsyncIteration` system here for returning a final
|
||||
# value if desired
|
||||
await chan.send({'stop': None, 'cid': cid})
|
||||
await chan.send({'stop': True, 'cid': cid})
|
||||
else:
|
||||
if treat_as_gen:
|
||||
await chan.send({'functype': 'asyncgen', 'cid': cid})
|
||||
|
@ -117,7 +117,7 @@ async def _invoke(
|
|||
if not cs.cancelled_caught:
|
||||
# task was not cancelled so we can instruct the
|
||||
# far end async gen to tear down
|
||||
await chan.send({'stop': None, 'cid': cid})
|
||||
await chan.send({'stop': True, 'cid': cid})
|
||||
else:
|
||||
await chan.send({'functype': 'asyncfunction', 'cid': cid})
|
||||
with trio.open_cancel_scope() as cs:
|
||||
|
@ -141,7 +141,7 @@ async def _invoke(
|
|||
tasks = actor._rpc_tasks.get(chan, None)
|
||||
if tasks:
|
||||
try:
|
||||
tasks.remove((cs, func))
|
||||
scope, func = tasks.pop(cid)
|
||||
except ValueError:
|
||||
# If we're cancelled before the task returns then the
|
||||
# cancel scope will not have been inserted yet
|
||||
|
@ -197,7 +197,7 @@ class Actor:
|
|||
self._no_more_rpc_tasks.set()
|
||||
self._rpc_tasks: Dict[
|
||||
Channel,
|
||||
List[Tuple[trio._core._run.CancelScope, typing.Callable]]
|
||||
Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]]
|
||||
] = {}
|
||||
# map {uids -> {callids -> waiter queues}}
|
||||
self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
|
||||
|
@ -344,9 +344,10 @@ class Actor:
|
|||
if msg is None: # terminate sentinel
|
||||
log.debug(
|
||||
f"Cancelling all tasks for {chan} from {chan.uid}")
|
||||
for scope, func in self._rpc_tasks.pop(chan, ()):
|
||||
for cid, (scope, func) in self._rpc_tasks.pop(
|
||||
chan, {}
|
||||
).items():
|
||||
scope.cancel()
|
||||
|
||||
log.debug(
|
||||
f"Msg loop signalled to terminate for"
|
||||
f" {chan} from {chan.uid}")
|
||||
|
@ -354,6 +355,16 @@ class Actor:
|
|||
log.debug(f"Received msg {msg} from {chan.uid}")
|
||||
cid = msg.get('cid')
|
||||
if cid:
|
||||
cancel = msg.get('cancel')
|
||||
if cancel:
|
||||
# right now this is only implicitly used by
|
||||
# async generator IPC
|
||||
scope, func = self._rpc_tasks[chan][cid]
|
||||
log.debug(
|
||||
f"Received cancel request for task {cid}"
|
||||
f" from {chan.uid}")
|
||||
scope.cancel()
|
||||
else:
|
||||
# deliver response to local caller/waiter
|
||||
await self._push_result(chan.uid, cid, msg)
|
||||
log.debug(
|
||||
|
@ -403,7 +414,7 @@ class Actor:
|
|||
log.info(f"RPC func is {func}")
|
||||
# store cancel scope such that the rpc task can be
|
||||
# cancelled gracefully if requested
|
||||
self._rpc_tasks.setdefault(chan, []).append((cs, func))
|
||||
self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func)
|
||||
log.debug(
|
||||
f"Waiting on next msg for {chan} from {chan.uid}")
|
||||
else:
|
||||
|
@ -611,9 +622,9 @@ class Actor:
|
|||
"""
|
||||
tasks = self._rpc_tasks
|
||||
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}")
|
||||
for chan, scopes in tasks.items():
|
||||
for chan, cids2scopes in tasks.items():
|
||||
log.debug(f"Cancelling all tasks for {chan.uid}")
|
||||
for scope, func in scopes:
|
||||
for cid, (scope, func) in cids2scopes.items():
|
||||
log.debug(f"Cancelling task for {func}")
|
||||
scope.cancel()
|
||||
if tasks:
|
||||
|
|
Loading…
Reference in New Issue