Merge pull request #50 from tgoodlet/remote_task_cancelling
Remote task cancellingremote_module_errors
commit
aa479d64b0
|
@ -13,6 +13,12 @@ async def stream_seq(sequence):
|
||||||
yield i
|
yield i
|
||||||
await trio.sleep(0.1)
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
|
# block indefinitely waiting to be cancelled by ``aclose()`` call
|
||||||
|
with trio.open_cancel_scope() as cs:
|
||||||
|
await trio.sleep(float('inf'))
|
||||||
|
assert 0
|
||||||
|
assert cs.cancelled_caught
|
||||||
|
|
||||||
|
|
||||||
async def stream_from_single_subactor():
|
async def stream_from_single_subactor():
|
||||||
"""Verify we can spawn a daemon actor and retrieve streamed data.
|
"""Verify we can spawn a daemon actor and retrieve streamed data.
|
||||||
|
@ -37,17 +43,22 @@ async def stream_from_single_subactor():
|
||||||
)
|
)
|
||||||
# it'd sure be nice to have an asyncitertools here...
|
# it'd sure be nice to have an asyncitertools here...
|
||||||
iseq = iter(seq)
|
iseq = iter(seq)
|
||||||
|
ival = next(iseq)
|
||||||
async for val in agen:
|
async for val in agen:
|
||||||
assert val == next(iseq)
|
assert val == ival
|
||||||
# TODO: test breaking the loop (should it kill the
|
try:
|
||||||
# far end?)
|
ival = next(iseq)
|
||||||
# break
|
except StopIteration:
|
||||||
# terminate far-end async-gen
|
# should cancel far end task which will be
|
||||||
# await gen.asend(None)
|
# caught and no error is raised
|
||||||
# break
|
await agen.aclose()
|
||||||
|
|
||||||
# stop all spawned subactors
|
await trio.sleep(0.3)
|
||||||
await portal.cancel_actor()
|
try:
|
||||||
|
await agen.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
# stop all spawned subactors
|
||||||
|
await portal.cancel_actor()
|
||||||
# await nursery.cancel()
|
# await nursery.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -103,7 +103,7 @@ async def _invoke(
|
||||||
# TODO: we should really support a proper
|
# TODO: we should really support a proper
|
||||||
# `StopAsyncIteration` system here for returning a final
|
# `StopAsyncIteration` system here for returning a final
|
||||||
# value if desired
|
# value if desired
|
||||||
await chan.send({'stop': None, 'cid': cid})
|
await chan.send({'stop': True, 'cid': cid})
|
||||||
else:
|
else:
|
||||||
if treat_as_gen:
|
if treat_as_gen:
|
||||||
await chan.send({'functype': 'asyncgen', 'cid': cid})
|
await chan.send({'functype': 'asyncgen', 'cid': cid})
|
||||||
|
@ -117,7 +117,7 @@ async def _invoke(
|
||||||
if not cs.cancelled_caught:
|
if not cs.cancelled_caught:
|
||||||
# task was not cancelled so we can instruct the
|
# task was not cancelled so we can instruct the
|
||||||
# far end async gen to tear down
|
# far end async gen to tear down
|
||||||
await chan.send({'stop': None, 'cid': cid})
|
await chan.send({'stop': True, 'cid': cid})
|
||||||
else:
|
else:
|
||||||
await chan.send({'functype': 'asyncfunction', 'cid': cid})
|
await chan.send({'functype': 'asyncfunction', 'cid': cid})
|
||||||
with trio.open_cancel_scope() as cs:
|
with trio.open_cancel_scope() as cs:
|
||||||
|
@ -141,7 +141,7 @@ async def _invoke(
|
||||||
tasks = actor._rpc_tasks.get(chan, None)
|
tasks = actor._rpc_tasks.get(chan, None)
|
||||||
if tasks:
|
if tasks:
|
||||||
try:
|
try:
|
||||||
tasks.remove((cs, func))
|
scope, func = tasks.pop(cid)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# If we're cancelled before the task returns then the
|
# If we're cancelled before the task returns then the
|
||||||
# cancel scope will not have been inserted yet
|
# cancel scope will not have been inserted yet
|
||||||
|
@ -197,7 +197,7 @@ class Actor:
|
||||||
self._no_more_rpc_tasks.set()
|
self._no_more_rpc_tasks.set()
|
||||||
self._rpc_tasks: Dict[
|
self._rpc_tasks: Dict[
|
||||||
Channel,
|
Channel,
|
||||||
List[Tuple[trio._core._run.CancelScope, typing.Callable]]
|
Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]]
|
||||||
] = {}
|
] = {}
|
||||||
# map {uids -> {callids -> waiter queues}}
|
# map {uids -> {callids -> waiter queues}}
|
||||||
self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
|
self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
|
||||||
|
@ -294,8 +294,11 @@ class Actor:
|
||||||
# # XXX: is this necessary (GC should do it?)
|
# # XXX: is this necessary (GC should do it?)
|
||||||
if chan.connected():
|
if chan.connected():
|
||||||
log.debug(f"Disconnecting channel {chan}")
|
log.debug(f"Disconnecting channel {chan}")
|
||||||
await chan.send(None)
|
try:
|
||||||
await chan.aclose()
|
await chan.send(None)
|
||||||
|
await chan.aclose()
|
||||||
|
except trio.BrokenResourceError:
|
||||||
|
log.exception(f"Channel for {chan.uid} was already zonked..")
|
||||||
|
|
||||||
async def _push_result(self, actorid, cid: str, msg: dict) -> None:
|
async def _push_result(self, actorid, cid: str, msg: dict) -> None:
|
||||||
"""Push an RPC result to the local consumer's queue.
|
"""Push an RPC result to the local consumer's queue.
|
||||||
|
@ -344,9 +347,10 @@ class Actor:
|
||||||
if msg is None: # terminate sentinel
|
if msg is None: # terminate sentinel
|
||||||
log.debug(
|
log.debug(
|
||||||
f"Cancelling all tasks for {chan} from {chan.uid}")
|
f"Cancelling all tasks for {chan} from {chan.uid}")
|
||||||
for scope, func in self._rpc_tasks.pop(chan, ()):
|
for cid, (scope, func) in self._rpc_tasks.pop(
|
||||||
|
chan, {}
|
||||||
|
).items():
|
||||||
scope.cancel()
|
scope.cancel()
|
||||||
|
|
||||||
log.debug(
|
log.debug(
|
||||||
f"Msg loop signalled to terminate for"
|
f"Msg loop signalled to terminate for"
|
||||||
f" {chan} from {chan.uid}")
|
f" {chan} from {chan.uid}")
|
||||||
|
@ -354,10 +358,20 @@ class Actor:
|
||||||
log.debug(f"Received msg {msg} from {chan.uid}")
|
log.debug(f"Received msg {msg} from {chan.uid}")
|
||||||
cid = msg.get('cid')
|
cid = msg.get('cid')
|
||||||
if cid:
|
if cid:
|
||||||
# deliver response to local caller/waiter
|
cancel = msg.get('cancel')
|
||||||
await self._push_result(chan.uid, cid, msg)
|
if cancel:
|
||||||
log.debug(
|
# right now this is only implicitly used by
|
||||||
f"Waiting on next msg for {chan} from {chan.uid}")
|
# async generator IPC
|
||||||
|
scope, func = self._rpc_tasks[chan][cid]
|
||||||
|
log.debug(
|
||||||
|
f"Received cancel request for task {cid}"
|
||||||
|
f" from {chan.uid}")
|
||||||
|
scope.cancel()
|
||||||
|
else:
|
||||||
|
# deliver response to local caller/waiter
|
||||||
|
await self._push_result(chan.uid, cid, msg)
|
||||||
|
log.debug(
|
||||||
|
f"Waiting on next msg for {chan} from {chan.uid}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# process command request
|
# process command request
|
||||||
|
@ -403,7 +417,7 @@ class Actor:
|
||||||
log.info(f"RPC func is {func}")
|
log.info(f"RPC func is {func}")
|
||||||
# store cancel scope such that the rpc task can be
|
# store cancel scope such that the rpc task can be
|
||||||
# cancelled gracefully if requested
|
# cancelled gracefully if requested
|
||||||
self._rpc_tasks.setdefault(chan, []).append((cs, func))
|
self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func)
|
||||||
log.debug(
|
log.debug(
|
||||||
f"Waiting on next msg for {chan} from {chan.uid}")
|
f"Waiting on next msg for {chan} from {chan.uid}")
|
||||||
else:
|
else:
|
||||||
|
@ -611,9 +625,9 @@ class Actor:
|
||||||
"""
|
"""
|
||||||
tasks = self._rpc_tasks
|
tasks = self._rpc_tasks
|
||||||
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}")
|
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}")
|
||||||
for chan, scopes in tasks.items():
|
for chan, cids2scopes in tasks.items():
|
||||||
log.debug(f"Cancelling all tasks for {chan.uid}")
|
log.debug(f"Cancelling all tasks for {chan.uid}")
|
||||||
for scope, func in scopes:
|
for cid, (scope, func) in cids2scopes.items():
|
||||||
log.debug(f"Cancelling task for {func}")
|
log.debug(f"Cancelling task for {func}")
|
||||||
scope.cancel()
|
scope.cancel()
|
||||||
if tasks:
|
if tasks:
|
||||||
|
|
|
@ -20,7 +20,7 @@ class StreamQueue:
|
||||||
self._agen = self._iter_packets()
|
self._agen = self._iter_packets()
|
||||||
self._laddr = self.stream.socket.getsockname()[:2]
|
self._laddr = self.stream.socket.getsockname()[:2]
|
||||||
self._raddr = self.stream.socket.getpeername()[:2]
|
self._raddr = self.stream.socket.getpeername()[:2]
|
||||||
self._send_lock = trio.Lock()
|
self._send_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
||||||
"""Yield packets from the underlying stream.
|
"""Yield packets from the underlying stream.
|
||||||
|
|
|
@ -4,7 +4,7 @@ Portal api
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import typing
|
import typing
|
||||||
from typing import Tuple, Any, Dict, Optional
|
from typing import Tuple, Any, Dict, Optional, Set
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
from async_generator import asynccontextmanager
|
from async_generator import asynccontextmanager
|
||||||
|
@ -65,6 +65,7 @@ class Portal:
|
||||||
self._expect_result: Optional[
|
self._expect_result: Optional[
|
||||||
Tuple[str, Any, str, Dict[str, Any]]
|
Tuple[str, Any, str, Dict[str, Any]]
|
||||||
] = None
|
] = None
|
||||||
|
self._agens: Set[typing.AsyncGenerator] = set()
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
log.debug(f"Closing {self}")
|
log.debug(f"Closing {self}")
|
||||||
|
@ -139,15 +140,19 @@ class Portal:
|
||||||
"Received internal error at portal?")
|
"Received internal error at portal?")
|
||||||
raise unpack_error(msg, self.channel)
|
raise unpack_error(msg, self.channel)
|
||||||
|
|
||||||
except StopAsyncIteration:
|
except GeneratorExit:
|
||||||
log.debug(
|
# for now this msg cancels an ongoing remote task
|
||||||
|
await self.channel.send({'cancel': True, 'cid': cid})
|
||||||
|
log.warn(
|
||||||
f"Cancelling async gen call {cid} to "
|
f"Cancelling async gen call {cid} to "
|
||||||
f"{self.channel.uid}")
|
f"{self.channel.uid}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# TODO: use AsyncExitStack to aclose() all agens
|
# TODO: use AsyncExitStack to aclose() all agens
|
||||||
# on teardown
|
# on teardown
|
||||||
return yield_from_q()
|
agen = yield_from_q()
|
||||||
|
self._agens.add(agen)
|
||||||
|
return agen
|
||||||
|
|
||||||
elif resptype == 'return':
|
elif resptype == 'return':
|
||||||
msg = await q.get()
|
msg = await q.get()
|
||||||
|
@ -267,13 +272,18 @@ async def open_portal(
|
||||||
|
|
||||||
nursery.start_soon(actor._process_messages, channel)
|
nursery.start_soon(actor._process_messages, channel)
|
||||||
portal = Portal(channel)
|
portal = Portal(channel)
|
||||||
yield portal
|
try:
|
||||||
|
yield portal
|
||||||
|
finally:
|
||||||
|
# tear down all async generators
|
||||||
|
for agen in portal._agens:
|
||||||
|
await agen.aclose()
|
||||||
|
|
||||||
# cancel remote channel-msg loop
|
# cancel remote channel-msg loop
|
||||||
if channel.connected():
|
if channel.connected():
|
||||||
await portal.close()
|
await portal.close()
|
||||||
|
|
||||||
# cancel background msg loop task
|
# cancel background msg loop task
|
||||||
nursery.cancel_scope.cancel()
|
nursery.cancel_scope.cancel()
|
||||||
if was_connected:
|
if was_connected:
|
||||||
await channel.aclose()
|
await channel.aclose()
|
||||||
|
|
Loading…
Reference in New Issue