Merge pull request #50 from tgoodlet/remote_task_cancelling

Remote task cancelling
remote_module_errors
goodboy 2019-01-01 15:22:53 -05:00 committed by GitHub
commit aa479d64b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 37 deletions

View File

@ -13,6 +13,12 @@ async def stream_seq(sequence):
yield i
await trio.sleep(0.1)
# block indefinitely waiting to be cancelled by ``aclose()`` call
with trio.open_cancel_scope() as cs:
await trio.sleep(float('inf'))
assert 0
assert cs.cancelled_caught
async def stream_from_single_subactor():
"""Verify we can spawn a daemon actor and retrieve streamed data.
@ -37,15 +43,20 @@ async def stream_from_single_subactor():
)
# it'd sure be nice to have an asyncitertools here...
iseq = iter(seq)
ival = next(iseq)
async for val in agen:
assert val == next(iseq)
# TODO: test breaking the loop (should it kill the
# far end?)
# break
# terminate far-end async-gen
# await gen.asend(None)
# break
assert val == ival
try:
ival = next(iseq)
except StopIteration:
# should cancel far end task which will be
# caught and no error is raised
await agen.aclose()
await trio.sleep(0.3)
try:
await agen.__anext__()
except StopAsyncIteration:
# stop all spawned subactors
await portal.cancel_actor()
# await nursery.cancel()

View File

@ -103,7 +103,7 @@ async def _invoke(
# TODO: we should really support a proper
# `StopAsyncIteration` system here for returning a final
# value if desired
await chan.send({'stop': None, 'cid': cid})
await chan.send({'stop': True, 'cid': cid})
else:
if treat_as_gen:
await chan.send({'functype': 'asyncgen', 'cid': cid})
@ -117,7 +117,7 @@ async def _invoke(
if not cs.cancelled_caught:
# task was not cancelled so we can instruct the
# far end async gen to tear down
await chan.send({'stop': None, 'cid': cid})
await chan.send({'stop': True, 'cid': cid})
else:
await chan.send({'functype': 'asyncfunction', 'cid': cid})
with trio.open_cancel_scope() as cs:
@ -141,7 +141,7 @@ async def _invoke(
tasks = actor._rpc_tasks.get(chan, None)
if tasks:
try:
tasks.remove((cs, func))
scope, func = tasks.pop(cid)
except ValueError:
# If we're cancelled before the task returns then the
# cancel scope will not have been inserted yet
@ -197,7 +197,7 @@ class Actor:
self._no_more_rpc_tasks.set()
self._rpc_tasks: Dict[
Channel,
List[Tuple[trio._core._run.CancelScope, typing.Callable]]
Dict[str, Tuple[trio._core._run.CancelScope, typing.Callable]]
] = {}
# map {uids -> {callids -> waiter queues}}
self._actors2calls: Dict[Tuple[str, str], Dict[str, trio.Queue]] = {}
@ -294,8 +294,11 @@ class Actor:
# # XXX: is this necessary (GC should do it?)
if chan.connected():
log.debug(f"Disconnecting channel {chan}")
try:
await chan.send(None)
await chan.aclose()
except trio.BrokenResourceError:
log.exception(f"Channel for {chan.uid} was already zonked..")
async def _push_result(self, actorid, cid: str, msg: dict) -> None:
"""Push an RPC result to the local consumer's queue.
@ -344,9 +347,10 @@ class Actor:
if msg is None: # terminate sentinel
log.debug(
f"Cancelling all tasks for {chan} from {chan.uid}")
for scope, func in self._rpc_tasks.pop(chan, ()):
for cid, (scope, func) in self._rpc_tasks.pop(
chan, {}
).items():
scope.cancel()
log.debug(
f"Msg loop signalled to terminate for"
f" {chan} from {chan.uid}")
@ -354,6 +358,16 @@ class Actor:
log.debug(f"Received msg {msg} from {chan.uid}")
cid = msg.get('cid')
if cid:
cancel = msg.get('cancel')
if cancel:
# right now this is only implicitly used by
# async generator IPC
scope, func = self._rpc_tasks[chan][cid]
log.debug(
f"Received cancel request for task {cid}"
f" from {chan.uid}")
scope.cancel()
else:
# deliver response to local caller/waiter
await self._push_result(chan.uid, cid, msg)
log.debug(
@ -403,7 +417,7 @@ class Actor:
log.info(f"RPC func is {func}")
# store cancel scope such that the rpc task can be
# cancelled gracefully if requested
self._rpc_tasks.setdefault(chan, []).append((cs, func))
self._rpc_tasks.setdefault(chan, {})[cid] = (cs, func)
log.debug(
f"Waiting on next msg for {chan} from {chan.uid}")
else:
@ -611,9 +625,9 @@ class Actor:
"""
tasks = self._rpc_tasks
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks}")
for chan, scopes in tasks.items():
for chan, cids2scopes in tasks.items():
log.debug(f"Cancelling all tasks for {chan.uid}")
for scope, func in scopes:
for cid, (scope, func) in cids2scopes.items():
log.debug(f"Cancelling task for {func}")
scope.cancel()
if tasks:

View File

@ -20,7 +20,7 @@ class StreamQueue:
self._agen = self._iter_packets()
self._laddr = self.stream.socket.getsockname()[:2]
self._raddr = self.stream.socket.getpeername()[:2]
self._send_lock = trio.Lock()
self._send_lock = trio.StrictFIFOLock()
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream.

View File

@ -4,7 +4,7 @@ Portal api
import importlib
import inspect
import typing
from typing import Tuple, Any, Dict, Optional
from typing import Tuple, Any, Dict, Optional, Set
import trio
from async_generator import asynccontextmanager
@ -65,6 +65,7 @@ class Portal:
self._expect_result: Optional[
Tuple[str, Any, str, Dict[str, Any]]
] = None
self._agens: Set[typing.AsyncGenerator] = set()
async def aclose(self) -> None:
log.debug(f"Closing {self}")
@ -139,15 +140,19 @@ class Portal:
"Received internal error at portal?")
raise unpack_error(msg, self.channel)
except StopAsyncIteration:
log.debug(
except GeneratorExit:
# for now this msg cancels an ongoing remote task
await self.channel.send({'cancel': True, 'cid': cid})
log.warn(
f"Cancelling async gen call {cid} to "
f"{self.channel.uid}")
raise
# TODO: use AsyncExitStack to aclose() all agens
# on teardown
return yield_from_q()
agen = yield_from_q()
self._agens.add(agen)
return agen
elif resptype == 'return':
msg = await q.get()
@ -267,7 +272,12 @@ async def open_portal(
nursery.start_soon(actor._process_messages, channel)
portal = Portal(channel)
try:
yield portal
finally:
# tear down all async generators
for agen in portal._agens:
await agen.aclose()
# cancel remote channel-msg loop
if channel.connected():