Provide each task's cancel scope to every `Context`
This begins moving toward explicitly decorated "streaming functions" instead of checking for a `ctx` arg in the signature. - provide each context with its task's top level `trio.CancelScope` such that tasks can cancel themselves explictly if needed via calling `Context.cancel_scope()` - make `Actor.cancel_task()` a private method (`_cancel_task()`) and handle remote rpc calls specially such that the caller does not need to provide the `chan` argument; non-primitive types can't be passed on the wire and we don't want the client actor be require knowledge of the channel instance the request is associated with. This also ties into how we're tracking tasks right now (`Actor._rpc_tasks` is keyed by the call id, a UUID, *plus* the channel). - make `_do_handshake` a private actor method - use UUID version 4stream_functions
parent
ac4a025aa5
commit
2aa6ffce60
|
@ -23,7 +23,6 @@ from ._exceptions import (
|
||||||
from ._portal import (
|
from ._portal import (
|
||||||
Portal,
|
Portal,
|
||||||
open_portal,
|
open_portal,
|
||||||
_do_handshake,
|
|
||||||
LocalPortal,
|
LocalPortal,
|
||||||
)
|
)
|
||||||
from . import _state
|
from . import _state
|
||||||
|
@ -50,7 +49,8 @@ async def _invoke(
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
treat_as_gen = False
|
treat_as_gen = False
|
||||||
cs = None
|
cs = None
|
||||||
ctx = Context(chan, cid)
|
cancel_scope = trio.CancelScope()
|
||||||
|
ctx = Context(chan, cid, cancel_scope)
|
||||||
if 'ctx' in sig.parameters:
|
if 'ctx' in sig.parameters:
|
||||||
kwargs['ctx'] = ctx
|
kwargs['ctx'] = ctx
|
||||||
# TODO: eventually we want to be more stringent
|
# TODO: eventually we want to be more stringent
|
||||||
|
@ -73,7 +73,7 @@ async def _invoke(
|
||||||
not is_async_gen_partial
|
not is_async_gen_partial
|
||||||
):
|
):
|
||||||
await chan.send({'functype': 'function', 'cid': cid})
|
await chan.send({'functype': 'function', 'cid': cid})
|
||||||
with trio.CancelScope() as cs:
|
with cancel_scope as cs:
|
||||||
task_status.started(cs)
|
task_status.started(cs)
|
||||||
await chan.send({'return': func(**kwargs), 'cid': cid})
|
await chan.send({'return': func(**kwargs), 'cid': cid})
|
||||||
else:
|
else:
|
||||||
|
@ -88,7 +88,7 @@ async def _invoke(
|
||||||
# have to properly handle the closing (aclosing)
|
# have to properly handle the closing (aclosing)
|
||||||
# of the async gen in order to be sure the cancel
|
# of the async gen in order to be sure the cancel
|
||||||
# is propagated!
|
# is propagated!
|
||||||
with trio.CancelScope() as cs:
|
with cancel_scope as cs:
|
||||||
task_status.started(cs)
|
task_status.started(cs)
|
||||||
async with aclosing(coro) as agen:
|
async with aclosing(coro) as agen:
|
||||||
async for item in agen:
|
async for item in agen:
|
||||||
|
@ -113,7 +113,7 @@ async def _invoke(
|
||||||
# back values like an async-generator would but must
|
# back values like an async-generator would but must
|
||||||
# manualy construct the response dict-packet-responses as
|
# manualy construct the response dict-packet-responses as
|
||||||
# above
|
# above
|
||||||
with trio.CancelScope() as cs:
|
with cancel_scope as cs:
|
||||||
task_status.started(cs)
|
task_status.started(cs)
|
||||||
await coro
|
await coro
|
||||||
if not cs.cancelled_caught:
|
if not cs.cancelled_caught:
|
||||||
|
@ -122,7 +122,7 @@ async def _invoke(
|
||||||
await chan.send({'stop': True, 'cid': cid})
|
await chan.send({'stop': True, 'cid': cid})
|
||||||
else:
|
else:
|
||||||
await chan.send({'functype': 'asyncfunction', 'cid': cid})
|
await chan.send({'functype': 'asyncfunction', 'cid': cid})
|
||||||
with trio.CancelScope() as cs:
|
with cancel_scope as cs:
|
||||||
task_status.started(cs)
|
task_status.started(cs)
|
||||||
await chan.send({'return': await coro, 'cid': cid})
|
await chan.send({'return': await coro, 'cid': cid})
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
@ -174,7 +174,7 @@ class Actor:
|
||||||
arbiter_addr: Optional[Tuple[str, int]] = None,
|
arbiter_addr: Optional[Tuple[str, int]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.uid = (name, uid or str(uuid.uuid1()))
|
self.uid = (name, uid or str(uuid.uuid4()))
|
||||||
self.rpc_module_paths = rpc_module_paths
|
self.rpc_module_paths = rpc_module_paths
|
||||||
self._mods: dict = {}
|
self._mods: dict = {}
|
||||||
# TODO: consider making this a dynamically defined
|
# TODO: consider making this a dynamically defined
|
||||||
|
@ -247,7 +247,7 @@ class Actor:
|
||||||
|
|
||||||
# send/receive initial handshake response
|
# send/receive initial handshake response
|
||||||
try:
|
try:
|
||||||
uid = await _do_handshake(self, chan)
|
uid = await self._do_handshake(chan)
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
log.warning(f"Channel {chan} failed to handshake")
|
log.warning(f"Channel {chan} failed to handshake")
|
||||||
return
|
return
|
||||||
|
@ -351,7 +351,7 @@ class Actor:
|
||||||
caller id and a ``trio.Queue`` that can be used to wait for
|
caller id and a ``trio.Queue`` that can be used to wait for
|
||||||
responses delivered by the local message processing loop.
|
responses delivered by the local message processing loop.
|
||||||
"""
|
"""
|
||||||
cid = str(uuid.uuid1())
|
cid = str(uuid.uuid4())
|
||||||
assert chan.uid
|
assert chan.uid
|
||||||
recv_chan = self.get_memchans(chan.uid, cid)
|
recv_chan = self.get_memchans(chan.uid, cid)
|
||||||
log.debug(f"Sending cmd to {chan.uid}: {ns}.{func}({kwargs})")
|
log.debug(f"Sending cmd to {chan.uid}: {ns}.{func}({kwargs})")
|
||||||
|
@ -373,11 +373,12 @@ class Actor:
|
||||||
msg = None
|
msg = None
|
||||||
log.debug(f"Entering msg loop for {chan} from {chan.uid}")
|
log.debug(f"Entering msg loop for {chan} from {chan.uid}")
|
||||||
try:
|
try:
|
||||||
# internal scope allows for keeping this message
|
with trio.CancelScope(shield=shield) as cs:
|
||||||
|
# this internal scope allows for keeping this message
|
||||||
# loop running despite the current task having been
|
# loop running despite the current task having been
|
||||||
# cancelled (eg. `open_portal()` may call this method from
|
# cancelled (eg. `open_portal()` may call this method from
|
||||||
# a locally spawned task)
|
# a locally spawned task) and recieve this scope using
|
||||||
with trio.CancelScope(shield=shield) as cs:
|
# ``scope = Nursery.start()``
|
||||||
task_status.started(cs)
|
task_status.started(cs)
|
||||||
async for msg in chan:
|
async for msg in chan:
|
||||||
if msg is None: # loop terminate sentinel
|
if msg is None: # loop terminate sentinel
|
||||||
|
@ -385,7 +386,7 @@ class Actor:
|
||||||
f"Cancelling all tasks for {chan} from {chan.uid}")
|
f"Cancelling all tasks for {chan} from {chan.uid}")
|
||||||
for (channel, cid) in self._rpc_tasks:
|
for (channel, cid) in self._rpc_tasks:
|
||||||
if channel is chan:
|
if channel is chan:
|
||||||
self.cancel_task(cid, Context(channel, cid))
|
self._cancel_task(cid, channel)
|
||||||
log.debug(
|
log.debug(
|
||||||
f"Msg loop signalled to terminate for"
|
f"Msg loop signalled to terminate for"
|
||||||
f" {chan} from {chan.uid}")
|
f" {chan} from {chan.uid}")
|
||||||
|
@ -419,6 +420,16 @@ class Actor:
|
||||||
f"{ns}.{funcname}({kwargs})")
|
f"{ns}.{funcname}({kwargs})")
|
||||||
if ns == 'self':
|
if ns == 'self':
|
||||||
func = getattr(self, funcname)
|
func = getattr(self, funcname)
|
||||||
|
if funcname == '_cancel_task':
|
||||||
|
# XXX: a special case is made here for
|
||||||
|
# remote calls since we don't want the
|
||||||
|
# remote actor have to know which channel
|
||||||
|
# the task is associated with and we can't
|
||||||
|
# pass non-primitive types between actors.
|
||||||
|
# This means you can use:
|
||||||
|
# Portal.run('self', '_cancel_task, cid=did)
|
||||||
|
# without passing the `chan` arg.
|
||||||
|
kwargs['chan'] = chan
|
||||||
else:
|
else:
|
||||||
# complain to client about restricted modules
|
# complain to client about restricted modules
|
||||||
try:
|
try:
|
||||||
|
@ -537,7 +548,7 @@ class Actor:
|
||||||
)
|
)
|
||||||
await chan.connect()
|
await chan.connect()
|
||||||
# initial handshake, report who we are, who they are
|
# initial handshake, report who we are, who they are
|
||||||
await _do_handshake(self, chan)
|
await self._do_handshake(chan)
|
||||||
except OSError: # failed to connect
|
except OSError: # failed to connect
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Failed to connect to parent @ {parent_addr},"
|
f"Failed to connect to parent @ {parent_addr},"
|
||||||
|
@ -661,21 +672,20 @@ class Actor:
|
||||||
self.cancel_server()
|
self.cancel_server()
|
||||||
self._root_nursery.cancel_scope.cancel()
|
self._root_nursery.cancel_scope.cancel()
|
||||||
|
|
||||||
async def cancel_task(self, cid, ctx):
|
async def _cancel_task(self, cid, chan):
|
||||||
"""Cancel a local task.
|
"""Cancel a local task by call-id / channel.
|
||||||
|
|
||||||
Note this method will be treated as a streaming funciton
|
Note this method will be treated as a streaming function
|
||||||
by remote actor-callers due to the declaration of ``ctx``
|
by remote actor-callers due to the declaration of ``ctx``
|
||||||
in the signature (for now).
|
in the signature (for now).
|
||||||
"""
|
"""
|
||||||
# right now this is only implicitly called by
|
# right now this is only implicitly called by
|
||||||
# streaming IPC but it should be called
|
# streaming IPC but it should be called
|
||||||
# to cancel any remotely spawned task
|
# to cancel any remotely spawned task
|
||||||
chan = ctx.chan
|
|
||||||
try:
|
try:
|
||||||
# this ctx based lookup ensures the requested task to
|
# this ctx based lookup ensures the requested task to
|
||||||
# be cancelled was indeed spawned by a request from this channel
|
# be cancelled was indeed spawned by a request from this channel
|
||||||
scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)]
|
scope, func, is_complete = self._rpc_tasks[(chan, cid)]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
log.warning(f"{cid} has already completed/terminated?")
|
log.warning(f"{cid} has already completed/terminated?")
|
||||||
return
|
return
|
||||||
|
@ -686,7 +696,7 @@ class Actor:
|
||||||
|
|
||||||
# don't allow cancelling this function mid-execution
|
# don't allow cancelling this function mid-execution
|
||||||
# (is this necessary?)
|
# (is this necessary?)
|
||||||
if func is self.cancel_task:
|
if func is self._cancel_task:
|
||||||
return
|
return
|
||||||
|
|
||||||
scope.cancel()
|
scope.cancel()
|
||||||
|
@ -704,7 +714,7 @@ class Actor:
|
||||||
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ")
|
log.info(f"Cancelling all {len(tasks)} rpc tasks:\n{tasks} ")
|
||||||
for (chan, cid) in tasks.copy():
|
for (chan, cid) in tasks.copy():
|
||||||
# TODO: this should really done in a nursery batch
|
# TODO: this should really done in a nursery batch
|
||||||
await self.cancel_task(cid, Context(chan, cid))
|
await self._cancel_task(cid, chan)
|
||||||
# if tasks:
|
# if tasks:
|
||||||
log.info(
|
log.info(
|
||||||
f"Waiting for remaining rpc tasks to complete {tasks}")
|
f"Waiting for remaining rpc tasks to complete {tasks}")
|
||||||
|
@ -735,6 +745,25 @@ class Actor:
|
||||||
"""Return all channels to the actor with provided uid."""
|
"""Return all channels to the actor with provided uid."""
|
||||||
return self._peers[uid]
|
return self._peers[uid]
|
||||||
|
|
||||||
|
async def _do_handshake(
|
||||||
|
self,
|
||||||
|
chan: Channel
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""Exchange (name, UUIDs) identifiers as the first communication step.
|
||||||
|
|
||||||
|
These are essentially the "mailbox addresses" found in actor model
|
||||||
|
parlance.
|
||||||
|
"""
|
||||||
|
await chan.send(self.uid)
|
||||||
|
uid: Tuple[str, str] = await chan.recv()
|
||||||
|
|
||||||
|
if not isinstance(uid, tuple):
|
||||||
|
raise ValueError(f"{uid} is not a valid uid?!")
|
||||||
|
|
||||||
|
chan.uid = uid
|
||||||
|
log.info(f"Handshake with actor {uid}@{chan.raddr} complete")
|
||||||
|
return uid
|
||||||
|
|
||||||
|
|
||||||
class Arbiter(Actor):
|
class Arbiter(Actor):
|
||||||
"""A special actor who knows all the other actors and always has
|
"""A special actor who knows all the other actors and always has
|
||||||
|
|
|
@ -215,10 +215,7 @@ class Context:
|
||||||
"""
|
"""
|
||||||
chan: Channel
|
chan: Channel
|
||||||
cid: str
|
cid: str
|
||||||
|
cancel_scope: trio.CancelScope
|
||||||
# TODO: we should probably attach the actor-task
|
|
||||||
# cancel scope here now that trio is exposing it
|
|
||||||
# as a public object
|
|
||||||
|
|
||||||
async def send_yield(self, data: Any) -> None:
|
async def send_yield(self, data: Any) -> None:
|
||||||
await self.chan.send({'yield': data, 'cid': self.cid})
|
await self.chan.send({'yield': data, 'cid': self.cid})
|
||||||
|
|
|
@ -33,21 +33,6 @@ async def maybe_open_nursery(nursery: trio._core._run.Nursery = None):
|
||||||
yield nursery
|
yield nursery
|
||||||
|
|
||||||
|
|
||||||
async def _do_handshake(
|
|
||||||
actor: 'Actor', # type: ignore
|
|
||||||
chan: Channel
|
|
||||||
) -> Any:
|
|
||||||
await chan.send(actor.uid)
|
|
||||||
uid: Tuple[str, str] = await chan.recv()
|
|
||||||
|
|
||||||
if not isinstance(uid, tuple):
|
|
||||||
raise ValueError(f"{uid} is not a valid uid?!")
|
|
||||||
|
|
||||||
chan.uid = uid
|
|
||||||
log.info(f"Handshake with actor {uid}@{chan.raddr} complete")
|
|
||||||
return uid
|
|
||||||
|
|
||||||
|
|
||||||
class StreamReceiveChannel(trio.abc.ReceiveChannel):
|
class StreamReceiveChannel(trio.abc.ReceiveChannel):
|
||||||
"""A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
|
"""A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
|
||||||
special behaviour for signalling stream termination across an
|
special behaviour for signalling stream termination across an
|
||||||
|
@ -95,8 +80,8 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
|
||||||
raise unpack_error(msg, self._portal.channel)
|
raise unpack_error(msg, self._portal.channel)
|
||||||
|
|
||||||
async def aclose(self):
|
async def aclose(self):
|
||||||
"""Cancel associate remote actor task on close
|
"""Cancel associated remote actor task and local memory channel
|
||||||
as well as the local memory channel.
|
on close.
|
||||||
"""
|
"""
|
||||||
if self._rx_chan._closed:
|
if self._rx_chan._closed:
|
||||||
log.warning(f"{self} is already closed")
|
log.warning(f"{self} is already closed")
|
||||||
|
@ -107,15 +92,10 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Cancelling stream {cid} to "
|
f"Cancelling stream {cid} to "
|
||||||
f"{self._portal.channel.uid}")
|
f"{self._portal.channel.uid}")
|
||||||
# TODO: yeah.. it'd be nice if this was just an
|
# NOTE: we're telling the far end actor to cancel a task
|
||||||
# async func on the far end. Gotta figure out a
|
# corresponding to *this actor*. The far end local channel
|
||||||
# better way then implicitly feeding the ctx
|
# instance is passed to `Actor._cancel_task()` implicitly.
|
||||||
# to declaring functions; likely a decorator
|
await self._portal.run('self', '_cancel_task', cid=cid)
|
||||||
# system.
|
|
||||||
rchan = await self._portal.run(
|
|
||||||
'self', 'cancel_task', cid=cid)
|
|
||||||
async for _ in rchan:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if cs.cancelled_caught:
|
if cs.cancelled_caught:
|
||||||
# XXX: there's no way to know if the remote task was indeed
|
# XXX: there's no way to know if the remote task was indeed
|
||||||
|
@ -153,6 +133,7 @@ class Portal:
|
||||||
Tuple[str, Any, str, Dict[str, Any]]
|
Tuple[str, Any, str, Dict[str, Any]]
|
||||||
] = None
|
] = None
|
||||||
self._streams: Set[StreamReceiveChannel] = set()
|
self._streams: Set[StreamReceiveChannel] = set()
|
||||||
|
self.actor = current_actor()
|
||||||
|
|
||||||
async def _submit(
|
async def _submit(
|
||||||
self,
|
self,
|
||||||
|
@ -167,7 +148,7 @@ class Portal:
|
||||||
This is an async call.
|
This is an async call.
|
||||||
"""
|
"""
|
||||||
# ship a function call request to the remote actor
|
# ship a function call request to the remote actor
|
||||||
cid, recv_chan = await current_actor().send_cmd(
|
cid, recv_chan = await self.actor.send_cmd(
|
||||||
self.channel, ns, func, kwargs)
|
self.channel, ns, func, kwargs)
|
||||||
|
|
||||||
# wait on first response msg and handle (this should be
|
# wait on first response msg and handle (this should be
|
||||||
|
@ -345,7 +326,7 @@ async def open_portal(
|
||||||
was_connected = True
|
was_connected = True
|
||||||
|
|
||||||
if channel.uid is None:
|
if channel.uid is None:
|
||||||
await _do_handshake(actor, channel)
|
await actor._do_handshake(channel)
|
||||||
|
|
||||||
msg_loop_cs = await nursery.start(
|
msg_loop_cs = await nursery.start(
|
||||||
partial(
|
partial(
|
||||||
|
|
Loading…
Reference in New Issue