Compare commits

...

9 Commits

Author SHA1 Message Date
Tyler Goodlet dca0378598 Avoid mutate during interate error 2021-07-02 17:45:17 -04:00
Tyler Goodlet e09dacc2bb Expect context cancelled when we cancel 2021-07-02 17:45:17 -04:00
Tyler Goodlet aa6902fcfc Add pre-stream open error conditions 2021-07-02 17:45:17 -04:00
Tyler Goodlet 477d0bc697 De-densify some code 2021-07-02 17:45:17 -04:00
Tyler Goodlet 9d302069f3 Always shield cancel the caller on cancel-causing-errors, add teardown logging 2021-07-02 17:45:17 -04:00
Tyler Goodlet d00e2ca573 First try: pack cancelled tracebacks and ship to caller 2021-07-02 17:45:17 -04:00
Tyler Goodlet 25679ae22d Add temp warning msg for context cancel call 2021-07-02 17:45:17 -04:00
Tyler Goodlet 2b3cb042c6 Add some brief todo notes on idea of shielded breakpoint 2021-07-02 17:45:17 -04:00
Tyler Goodlet 807c1d71a2 Consider relaying context error via raised-in-scope-nursery task 2021-07-02 17:45:17 -04:00
8 changed files with 302 additions and 131 deletions

View File

@ -262,7 +262,7 @@ async def test_caller_closes_ctx_after_callee_opens_stream(
async with ctx.open_stream() as stream: async with ctx.open_stream() as stream:
async for msg in stream: async for msg in stream:
pass pass
except trio.ClosedResourceError: except tractor.ContextCancelled:
pass pass
else: else:
assert 0, "Should have received closed resource error?" assert 0, "Should have received closed resource error?"

View File

@ -30,7 +30,7 @@ async def publisher(
sub = 'even' if is_even(val) else 'odd' sub = 'even' if is_even(val) else 'odd'
for sub_stream in _registry[sub]: for sub_stream in _registry[sub].copy():
await sub_stream.send(val) await sub_stream.send(val)
# throttle send rate to ~1kHz # throttle send rate to ~1kHz

View File

@ -46,6 +46,7 @@ class ActorFailure(Exception):
async def _invoke( async def _invoke(
actor: 'Actor', actor: 'Actor',
cid: str, cid: str,
chan: Channel, chan: Channel,
@ -58,10 +59,15 @@ async def _invoke(
"""Invoke local func and deliver result(s) over provided channel. """Invoke local func and deliver result(s) over provided channel.
""" """
treat_as_gen = False treat_as_gen = False
cs = None
# possible a traceback (not sure what typing is for this..)
tb = None
cancel_scope = trio.CancelScope() cancel_scope = trio.CancelScope()
ctx = Context(chan, cid, _cancel_scope=cancel_scope) cs: trio.CancelScope = None
context = False
ctx = Context(chan, cid)
context: bool = False
if getattr(func, '_tractor_stream_function', False): if getattr(func, '_tractor_stream_function', False):
# handle decorated ``@tractor.stream`` async functions # handle decorated ``@tractor.stream`` async functions
@ -149,14 +155,34 @@ async def _invoke(
# context func with support for bi-dir streaming # context func with support for bi-dir streaming
await chan.send({'functype': 'context', 'cid': cid}) await chan.send({'functype': 'context', 'cid': cid})
with cancel_scope as cs: async with trio.open_nursery() as scope_nursery:
ctx._scope_nursery = scope_nursery
cs = scope_nursery.cancel_scope
task_status.started(cs) task_status.started(cs)
try:
await chan.send({'return': await coro, 'cid': cid}) await chan.send({'return': await coro, 'cid': cid})
except trio.Cancelled as err:
tb = err.__traceback__
if cs.cancelled_caught: if cs.cancelled_caught:
# TODO: pack in ``trio.Cancelled.__traceback__`` here
# so they can be unwrapped and displayed on the caller
# side!
fname = func.__name__
if ctx._cancel_called:
msg = f'{fname} cancelled itself'
elif cs.cancel_called:
msg = (
f'{fname} was remotely cancelled by its caller '
f'{ctx.chan.uid}'
)
# task-contex was cancelled so relay to the cancel to caller # task-contex was cancelled so relay to the cancel to caller
raise ContextCancelled( raise ContextCancelled(
f'{func.__name__} cancelled itself', msg,
suberror_type=trio.Cancelled, suberror_type=trio.Cancelled,
) )
@ -185,14 +211,16 @@ async def _invoke(
log.exception("Actor crashed:") log.exception("Actor crashed:")
# always ship errors back to caller # always ship errors back to caller
err_msg = pack_error(err) err_msg = pack_error(err, tb=tb)
err_msg['cid'] = cid err_msg['cid'] = cid
try: try:
await chan.send(err_msg) await chan.send(err_msg)
except trio.ClosedResourceError: except trio.ClosedResourceError:
log.warning( # if we can't propagate the error that's a big boo boo
f"Failed to ship error to caller @ {chan.uid}") log.error(
f"Failed to ship error to caller @ {chan.uid} !?"
)
if cs is None: if cs is None:
# error is from above code not from rpc invocation # error is from above code not from rpc invocation
@ -210,7 +238,7 @@ async def _invoke(
f"Task {func} likely errored or cancelled before it started") f"Task {func} likely errored or cancelled before it started")
finally: finally:
if not actor._rpc_tasks: if not actor._rpc_tasks:
log.info("All RPC tasks have completed") log.runtime("All RPC tasks have completed")
actor._ongoing_rpc_tasks.set() actor._ongoing_rpc_tasks.set()
@ -225,10 +253,10 @@ _lifetime_stack: ExitStack = ExitStack()
class Actor: class Actor:
"""The fundamental concurrency primitive. """The fundamental concurrency primitive.
An *actor* is the combination of a regular Python or An *actor* is the combination of a regular Python process
``multiprocessing.Process`` executing a ``trio`` task tree, communicating executing a ``trio`` task tree, communicating
with other actors through "portals" which provide a native async API with other actors through "portals" which provide a native async API
around "channels". around various IPC transport "channels".
""" """
is_arbiter: bool = False is_arbiter: bool = False
@ -372,14 +400,18 @@ class Actor:
raise mne raise mne
async def _stream_handler( async def _stream_handler(
self, self,
stream: trio.SocketStream, stream: trio.SocketStream,
) -> None: ) -> None:
"""Entry point for new inbound connections to the channel server. """Entry point for new inbound connections to the channel server.
""" """
self._no_more_peers = trio.Event() # unset self._no_more_peers = trio.Event() # unset
chan = Channel(stream=stream) chan = Channel(stream=stream)
log.info(f"New connection to us {chan}") log.runtime(f"New connection to us {chan}")
# send/receive initial handshake response # send/receive initial handshake response
try: try:
@ -410,8 +442,12 @@ class Actor:
event.set() event.set()
chans = self._peers[uid] chans = self._peers[uid]
# TODO: re-use channels for new connections instead
# of always new ones; will require changing all the
# discovery funcs
if chans: if chans:
log.warning( log.runtime(
f"already have channel(s) for {uid}:{chans}?" f"already have channel(s) for {uid}:{chans}?"
) )
log.trace(f"Registered {chan} for {uid}") # type: ignore log.trace(f"Registered {chan} for {uid}") # type: ignore
@ -423,10 +459,24 @@ class Actor:
try: try:
await self._process_messages(chan) await self._process_messages(chan)
finally: finally:
# channel cleanup sequence
# for (channel, cid) in self._rpc_tasks.copy():
# if channel is chan:
# with trio.CancelScope(shield=True):
# await self._cancel_task(cid, channel)
# # close all consumer side task mem chans
# send_chan, _ = self._cids2qs[(chan.uid, cid)]
# assert send_chan.cid == cid # type: ignore
# await send_chan.aclose()
# Drop ref to channel so it can be gc-ed and disconnected # Drop ref to channel so it can be gc-ed and disconnected
log.debug(f"Releasing channel {chan} from {chan.uid}") log.debug(f"Releasing channel {chan} from {chan.uid}")
chans = self._peers.get(chan.uid) chans = self._peers.get(chan.uid)
chans.remove(chan) chans.remove(chan)
if not chans: if not chans:
log.debug(f"No more channels for {chan.uid}") log.debug(f"No more channels for {chan.uid}")
self._peers.pop(chan.uid, None) self._peers.pop(chan.uid, None)
@ -439,14 +489,22 @@ class Actor:
# # XXX: is this necessary (GC should do it?) # # XXX: is this necessary (GC should do it?)
if chan.connected(): if chan.connected():
# if the channel is still connected it may mean the far
# end has not closed and we may have gotten here due to
# an error and so we should at least try to terminate
# the channel from this end gracefully.
log.debug(f"Disconnecting channel {chan}") log.debug(f"Disconnecting channel {chan}")
try: try:
# send our msg loop terminate sentinel # send a msg loop terminate sentinel
await chan.send(None) await chan.send(None)
# XXX: do we want this?
# causes "[104] connection reset by peer" on other end
# await chan.aclose() # await chan.aclose()
except trio.BrokenResourceError: except trio.BrokenResourceError:
log.exception( log.warning(f"Channel for {chan.uid} was already closed")
f"Channel for {chan.uid} was already zonked..")
async def _push_result( async def _push_result(
self, self,
@ -456,18 +514,22 @@ class Actor:
) -> None: ) -> None:
"""Push an RPC result to the local consumer's queue. """Push an RPC result to the local consumer's queue.
""" """
actorid = chan.uid # actorid = chan.uid
assert actorid, f"`actorid` can't be {actorid}" assert chan.uid, f"`chan.uid` can't be {chan.uid}"
send_chan, recv_chan = self._cids2qs[(actorid, cid)] send_chan, recv_chan = self._cids2qs[(chan.uid, cid)]
assert send_chan.cid == cid # type: ignore assert send_chan.cid == cid # type: ignore
# if 'stop' in msg: if 'error' in msg:
ctx = getattr(recv_chan, '_ctx', None)
# if ctx:
# ctx._error_from_remote_msg(msg)
# log.debug(f"{send_chan} was terminated at remote end") # log.debug(f"{send_chan} was terminated at remote end")
# # indicate to consumer that far end has stopped # # indicate to consumer that far end has stopped
# return await send_chan.aclose() # return await send_chan.aclose()
try: try:
log.debug(f"Delivering {msg} from {actorid} to caller {cid}") log.debug(f"Delivering {msg} from {chan.uid} to caller {cid}")
# maintain backpressure # maintain backpressure
await send_chan.send(msg) await send_chan.send(msg)
@ -486,7 +548,9 @@ class Actor:
self, self,
actorid: Tuple[str, str], actorid: Tuple[str, str],
cid: str cid: str
) -> Tuple[trio.abc.SendChannel, trio.abc.ReceiveChannel]: ) -> Tuple[trio.abc.SendChannel, trio.abc.ReceiveChannel]:
log.debug(f"Getting result queue for {actorid} cid {cid}") log.debug(f"Getting result queue for {actorid} cid {cid}")
try: try:
send_chan, recv_chan = self._cids2qs[(actorid, cid)] send_chan, recv_chan = self._cids2qs[(actorid, cid)]
@ -548,9 +612,15 @@ class Actor:
if channel is chan: if channel is chan:
await self._cancel_task(cid, channel) await self._cancel_task(cid, channel)
# close all consumer side task mem chans
# send_chan, _ = self._cids2qs[(chan.uid, cid)]
# assert send_chan.cid == cid # type: ignore
# await send_chan.aclose()
log.debug( log.debug(
f"Msg loop signalled to terminate for" f"Msg loop signalled to terminate for"
f" {chan} from {chan.uid}") f" {chan} from {chan.uid}")
break break
log.trace( # type: ignore log.trace( # type: ignore
@ -621,7 +691,7 @@ class Actor:
else: else:
# mark that we have ongoing rpc tasks # mark that we have ongoing rpc tasks
self._ongoing_rpc_tasks = trio.Event() self._ongoing_rpc_tasks = trio.Event()
log.info(f"RPC func is {func}") log.runtime(f"RPC func is {func}")
# store cancel scope such that the rpc task can be # store cancel scope such that the rpc task can be
# cancelled gracefully if requested # cancelled gracefully if requested
self._rpc_tasks[(chan, cid)] = ( self._rpc_tasks[(chan, cid)] = (
@ -630,7 +700,7 @@ class Actor:
# self.cancel() was called so kill this msg loop # self.cancel() was called so kill this msg loop
# and break out into ``_async_main()`` # and break out into ``_async_main()``
log.warning( log.warning(
f"{self.uid} was remotely cancelled; " f"Actor {self.uid} was remotely cancelled; "
"waiting on cancellation completion..") "waiting on cancellation completion..")
await self._cancel_complete.wait() await self._cancel_complete.wait()
loop_cs.cancel() loop_cs.cancel()
@ -648,17 +718,13 @@ class Actor:
except ( except (
TransportClosed, TransportClosed,
trio.BrokenResourceError, trio.BrokenResourceError,
trio.ClosedResourceError # trio.ClosedResourceError
): ):
# channels "breaking" is ok since we don't have a teardown # channels "breaking" is ok since we don't have a teardown
# handshake for them (yet) and instead we simply bail out # handshake for them (yet) and instead we simply bail out
# of the message loop and expect the teardown sequence # of the message loop and expect the surrounding
# to clean up. # caller's teardown sequence to clean up.
log.error(f"{chan} form {chan.uid} closed abruptly") log.warning(f"Channel from {chan.uid} closed abruptly")
# raise
except trio.ClosedResourceError:
log.error(f"{chan} form {chan.uid} broke")
except (Exception, trio.MultiError) as err: except (Exception, trio.MultiError) as err:
# ship any "internal" exception (i.e. one from internal machinery # ship any "internal" exception (i.e. one from internal machinery
@ -1102,7 +1168,7 @@ class Actor:
raise ValueError(f"{uid} is not a valid uid?!") raise ValueError(f"{uid} is not a valid uid?!")
chan.uid = uid chan.uid = uid
log.info(f"Handshake with actor {uid}@{chan.raddr} complete") log.runtime(f"Handshake with actor {uid}@{chan.raddr} complete")
return uid return uid

View File

@ -207,11 +207,24 @@ async def _hijack_stdin_relay_to_child(
return "pdb_unlock_complete" return "pdb_unlock_complete"
async def _breakpoint(debug_func) -> None: async def _breakpoint(
"""``tractor`` breakpoint entry for engaging pdb machinery
in subactors. debug_func,
# TODO:
# shield: bool = False
) -> None:
'''``tractor`` breakpoint entry for engaging pdb machinery
in the root or a subactor.
'''
# TODO: is it possible to debug a trio.Cancelled except block?
# right now it seems like we can kinda do with by shielding
# around ``tractor.breakpoint()`` but not if we move the shielded
# scope here???
# with trio.CancelScope(shield=shield):
"""
actor = tractor.current_actor() actor = tractor.current_actor()
task_name = trio.lowlevel.current_task().name task_name = trio.lowlevel.current_task().name

View File

@ -16,12 +16,14 @@ from ._state import current_actor, _runtime_vars
@asynccontextmanager @asynccontextmanager
async def get_arbiter( async def get_arbiter(
host: str, host: str,
port: int, port: int,
) -> typing.AsyncGenerator[Union[Portal, LocalPortal], None]: ) -> typing.AsyncGenerator[Union[Portal, LocalPortal], None]:
"""Return a portal instance connected to a local or remote '''Return a portal instance connected to a local or remote
arbiter. arbiter.
""" '''
actor = current_actor() actor = current_actor()
if not actor: if not actor:
@ -33,16 +35,20 @@ async def get_arbiter(
yield LocalPortal(actor, Channel((host, port))) yield LocalPortal(actor, Channel((host, port)))
else: else:
async with _connect_chan(host, port) as chan: async with _connect_chan(host, port) as chan:
async with open_portal(chan) as arb_portal: async with open_portal(chan) as arb_portal:
yield arb_portal yield arb_portal
@asynccontextmanager @asynccontextmanager
async def get_root( async def get_root(
**kwargs, **kwargs,
) -> typing.AsyncGenerator[Union[Portal, LocalPortal], None]: ) -> typing.AsyncGenerator[Union[Portal, LocalPortal], None]:
host, port = _runtime_vars['_root_mailbox'] host, port = _runtime_vars['_root_mailbox']
assert host is not None assert host is not None
async with _connect_chan(host, port) as chan: async with _connect_chan(host, port) as chan:
async with open_portal(chan, **kwargs) as portal: async with open_portal(chan, **kwargs) as portal:
yield portal yield portal
@ -60,12 +66,16 @@ async def find_actor(
""" """
actor = current_actor() actor = current_actor()
async with get_arbiter(*arbiter_sockaddr or actor._arb_addr) as arb_portal: async with get_arbiter(*arbiter_sockaddr or actor._arb_addr) as arb_portal:
sockaddr = await arb_portal.run_from_ns('self', 'find_actor', name=name) sockaddr = await arb_portal.run_from_ns('self', 'find_actor', name=name)
# TODO: return portals to all available actors - for now just # TODO: return portals to all available actors - for now just
# the last one that registered # the last one that registered
if name == 'arbiter' and actor.is_arbiter: if name == 'arbiter' and actor.is_arbiter:
raise RuntimeError("The current actor is the arbiter") raise RuntimeError("The current actor is the arbiter")
elif sockaddr: elif sockaddr:
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(*sockaddr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal
@ -83,9 +93,12 @@ async def wait_for_actor(
A portal to the first registered actor is returned. A portal to the first registered actor is returned.
""" """
actor = current_actor() actor = current_actor()
async with get_arbiter(*arbiter_sockaddr or actor._arb_addr) as arb_portal: async with get_arbiter(*arbiter_sockaddr or actor._arb_addr) as arb_portal:
sockaddrs = await arb_portal.run_from_ns('self', 'wait_for_actor', name=name) sockaddrs = await arb_portal.run_from_ns('self', 'wait_for_actor', name=name)
sockaddr = sockaddrs[-1] sockaddr = sockaddrs[-1]
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(*sockaddr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal

View File

@ -56,13 +56,22 @@ class NoRuntime(RuntimeError):
"The root actor has not been initialized yet" "The root actor has not been initialized yet"
def pack_error(exc: BaseException) -> Dict[str, Any]: def pack_error(
exc: BaseException,
tb = None,
) -> Dict[str, Any]:
"""Create an "error message" for tranmission over """Create an "error message" for tranmission over
a channel (aka the wire). a channel (aka the wire).
""" """
if tb:
tb_str = ''.join(traceback.format_tb(tb))
else:
tb_str = traceback.format_exc()
return { return {
'error': { 'error': {
'tb_str': traceback.format_exc(), 'tb_str': tb_str,
'type_str': type(exc).__name__, 'type_str': type(exc).__name__,
} }
} }

View File

@ -177,6 +177,7 @@ class Portal:
f"Cancelling all streams with {self.channel.uid}") f"Cancelling all streams with {self.channel.uid}")
for stream in self._streams.copy(): for stream in self._streams.copy():
try: try:
# with trio.CancelScope(shield=True):
await stream.aclose() await stream.aclose()
except trio.ClosedResourceError: except trio.ClosedResourceError:
# don't error the stream having already been closed # don't error the stream having already been closed
@ -294,6 +295,7 @@ class Portal:
self, self,
async_gen_func: Callable, # typing: ignore async_gen_func: Callable, # typing: ignore
**kwargs, **kwargs,
) -> AsyncGenerator[ReceiveMsgStream, None]: ) -> AsyncGenerator[ReceiveMsgStream, None]:
if not inspect.isasyncgenfunction(async_gen_func): if not inspect.isasyncgenfunction(async_gen_func):
@ -346,7 +348,6 @@ class Portal:
self, self,
func: Callable, func: Callable,
cancel_on_exit: bool = False,
**kwargs, **kwargs,
) -> AsyncGenerator[Tuple[Context, Any], None]: ) -> AsyncGenerator[Tuple[Context, Any], None]:
@ -358,6 +359,7 @@ class Portal:
and synchronized final result collection. See ``tractor.Context``. and synchronized final result collection. See ``tractor.Context``.
''' '''
# conduct target func method structural checks # conduct target func method structural checks
if not inspect.iscoroutinefunction(func) and ( if not inspect.iscoroutinefunction(func) and (
getattr(func, '_tractor_contex_function', False) getattr(func, '_tractor_contex_function', False)
@ -369,7 +371,6 @@ class Portal:
recv_chan: Optional[trio.MemoryReceiveChannel] = None recv_chan: Optional[trio.MemoryReceiveChannel] = None
try:
cid, recv_chan, functype, first_msg = await self._submit( cid, recv_chan, functype, first_msg = await self._submit(
fn_mod_path, fn_name, kwargs) fn_mod_path, fn_name, kwargs)
@ -390,44 +391,81 @@ class Portal:
else: else:
raise raise
_err = None
# deliver context instance and .started() msg value in open # deliver context instance and .started() msg value in open
# tuple. # tuple.
try:
async with trio.open_nursery() as scope_nursery:
ctx = Context( ctx = Context(
self.channel, self.channel,
cid, cid,
_portal=self, _portal=self,
_recv_chan=recv_chan, _recv_chan=recv_chan,
_scope_nursery=scope_nursery,
) )
recv_chan._ctx = ctx
try: # await trio.lowlevel.checkpoint()
yield ctx, first yield ctx, first
if cancel_on_exit: # if not ctx._cancel_called:
await ctx.cancel() # await ctx.result()
else: # await recv_chan.aclose()
except ContextCancelled as err:
_err = err
if not ctx._cancel_called: if not ctx._cancel_called:
await ctx.result() # context was cancelled at the far end but was
# not part of this end requesting that cancel
# so raise for the local task to respond and handle.
raise
except ContextCancelled:
# if the context was cancelled by client code # if the context was cancelled by client code
# then we don't need to raise since user code # then we don't need to raise since user code
# is expecting this. # is expecting this and the block should exit.
if not ctx._cancel_called: else:
raise log.debug(f'Context {ctx} cancelled gracefully')
except BaseException: except (
# the context cancels itself on any deviation trio.Cancelled,
trio.MultiError,
Exception,
) as err:
_err = err
# the context cancels itself on any cancel
# causing error.
log.error(f'Context {ctx} sending cancel to far end')
with trio.CancelScope(shield=True):
await ctx.cancel() await ctx.cancel()
raise raise
finally: finally:
log.info(f'Context for {func.__name__} completed') result = await ctx.result()
finally: # though it should be impossible for any tasks
# operating *in* this scope to have survived
# we tear down the runtime feeder chan last
# to avoid premature stream clobbers.
if recv_chan is not None: if recv_chan is not None:
await recv_chan.aclose() await recv_chan.aclose()
if _err:
if ctx._cancel_called:
log.warning(
f'Context {fn_name} cancelled by caller with\n{_err}'
)
elif _err is not None:
log.warning(
f'Context {fn_name} cancelled by callee with\n{_err}'
)
else:
log.info(
f'Context {fn_name} returned '
f'value from callee `{self._result}`'
)
@dataclass @dataclass
class LocalPortal: class LocalPortal:
"""A 'portal' to a local ``Actor``. """A 'portal' to a local ``Actor``.
@ -450,10 +488,12 @@ class LocalPortal:
@asynccontextmanager @asynccontextmanager
async def open_portal( async def open_portal(
channel: Channel, channel: Channel,
nursery: Optional[trio.Nursery] = None, nursery: Optional[trio.Nursery] = None,
start_msg_loop: bool = True, start_msg_loop: bool = True,
shield: bool = False, shield: bool = False,
) -> AsyncGenerator[Portal, None]: ) -> AsyncGenerator[Portal, None]:
"""Open a ``Portal`` through the provided ``channel``. """Open a ``Portal`` through the provided ``channel``.
@ -464,6 +504,7 @@ async def open_portal(
was_connected = False was_connected = False
async with maybe_open_nursery(nursery, shield=shield) as nursery: async with maybe_open_nursery(nursery, shield=shield) as nursery:
if not channel.connected(): if not channel.connected():
await channel.connect() await channel.connect()
was_connected = True was_connected = True
@ -485,12 +526,14 @@ async def open_portal(
portal = Portal(channel) portal = Portal(channel)
try: try:
yield portal yield portal
finally: finally:
await portal.aclose() await portal.aclose()
if was_connected: if was_connected:
# cancel remote channel-msg loop # gracefully signal remote channel-msg loop
await channel.send(None) await channel.send(None)
# await channel.aclose()
# cancel background msg loop task # cancel background msg loop task
if msg_loop_cs: if msg_loop_cs:

View File

@ -7,7 +7,7 @@ from contextlib import contextmanager, asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Any, Iterator, Optional, Callable, Any, Iterator, Optional, Callable,
AsyncGenerator, AsyncGenerator, Dict,
) )
import warnings import warnings
@ -15,7 +15,7 @@ import warnings
import trio import trio
from ._ipc import Channel from ._ipc import Channel
from ._exceptions import unpack_error from ._exceptions import unpack_error, ContextCancelled
from ._state import current_actor from ._state import current_actor
from .log import get_logger from .log import get_logger
@ -67,6 +67,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
raise trio.EndOfChannel raise trio.EndOfChannel
try: try:
msg = await self._rx_chan.receive() msg = await self._rx_chan.receive()
return msg['yield'] return msg['yield']
@ -134,11 +135,6 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
raise # propagate raise # propagate
# except trio.Cancelled:
# # relay cancels to the remote task
# await self.aclose()
# raise
@contextmanager @contextmanager
def shield( def shield(
self self
@ -212,7 +208,10 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# stop from the caller side # stop from the caller side
await self._ctx.send_stop() await self._ctx.send_stop()
except trio.BrokenResourceError: except (
trio.BrokenResourceError,
trio.ClosedResourceError
):
# the underlying channel may already have been pulled # the underlying channel may already have been pulled
# in which case our stop message is meaningless since # in which case our stop message is meaningless since
# it can't traverse the transport. # it can't traverse the transport.
@ -254,18 +253,6 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# still need to consume msgs that are "in transit" from the far # still need to consume msgs that are "in transit" from the far
# end (eg. for ``Context.result()``). # end (eg. for ``Context.result()``).
# TODO: but make it broadcasting to consumers
# def clone(self):
# """Clone this receive channel allowing for multi-task
# consumption from the same channel.
# """
# return ReceiveStream(
# self._cid,
# self._rx_chan.clone(),
# self._portal,
# )
class MsgStream(ReceiveMsgStream, trio.abc.Channel): class MsgStream(ReceiveMsgStream, trio.abc.Channel):
""" """
@ -282,6 +269,17 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
''' '''
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
# TODO: but make it broadcasting to consumers
def clone(self):
"""Clone this receive channel allowing for multi-task
consumption from the same channel.
"""
return MsgStream(
self._ctx,
self._rx_chan.clone(),
)
@dataclass @dataclass
class Context: class Context:
@ -308,7 +306,7 @@ class Context:
_cancel_called: bool = False _cancel_called: bool = False
# only set on the callee side # only set on the callee side
_cancel_scope: Optional[trio.CancelScope] = None _scope_nursery: Optional[trio.Nursery] = None
async def send_yield(self, data: Any) -> None: async def send_yield(self, data: Any) -> None:
@ -323,6 +321,22 @@ class Context:
async def send_stop(self) -> None: async def send_stop(self) -> None:
await self.chan.send({'stop': True, 'cid': self.cid}) await self.chan.send({'stop': True, 'cid': self.cid})
def _error_from_remote_msg(
self,
msg: Dict[str, Any],
) -> None:
'''Unpack and raise a msg error into the local scope
nursery for this context.
Acts as a form of "relay" for a remote error raised
in the corresponding remote callee task.
'''
async def raiser():
raise unpack_error(msg, self.chan)
self._scope_nursery.start_soon(raiser)
async def cancel(self) -> None: async def cancel(self) -> None:
'''Cancel this inter-actor-task context. '''Cancel this inter-actor-task context.
@ -330,9 +344,13 @@ class Context:
Timeout quickly in an attempt to sidestep 2-generals... Timeout quickly in an attempt to sidestep 2-generals...
''' '''
side = 'caller' if self._portal else 'callee'
log.warning(f'Cancelling {side} side of context to {self.chan}')
self._cancel_called = True self._cancel_called = True
if self._portal: # caller side: if side == 'caller':
if not self._portal: if not self._portal:
raise RuntimeError( raise RuntimeError(
"No portal found, this is likely a callee side context" "No portal found, this is likely a callee side context"
@ -360,14 +378,17 @@ class Context:
"May have failed to cancel remote task " "May have failed to cancel remote task "
f"{cid} for {self._portal.channel.uid}") f"{cid} for {self._portal.channel.uid}")
else: else:
# ensure callee side # callee side remote task
assert self._cancel_scope
# TODO: should we have an explicit cancel message # TODO: should we have an explicit cancel message
# or is relaying the local `trio.Cancelled` as an # or is relaying the local `trio.Cancelled` as an
# {'error': trio.Cancelled, cid: "blah"} enough? # {'error': trio.Cancelled, cid: "blah"} enough?
# This probably gets into the discussion in # This probably gets into the discussion in
# https://github.com/goodboy/tractor/issues/36 # https://github.com/goodboy/tractor/issues/36
self._cancel_scope.cancel() self._scope_nursery.cancel_scope.cancel()
if self._recv_chan:
await self._recv_chan.aclose()
@asynccontextmanager @asynccontextmanager
async def open_stream( async def open_stream(
@ -409,19 +430,25 @@ class Context:
self.cid self.cid
) )
# XXX: If the underlying receive mem chan has been closed then # Likewise if the surrounding context has been cancelled we error here
# likely client code has already exited a ``.open_stream()`` # since it likely means the surrounding block was exited or
# block prior. we error here until such a time that we decide # killed
# allowing streams to be "re-connected" is supported and/or
# a good idea. if self._cancel_called:
if recv_chan._closed:
task = trio.lowlevel.current_task().name task = trio.lowlevel.current_task().name
raise trio.ClosedResourceError( raise ContextCancelled(
f'stream for {actor.uid[0]}:{task} has already been closed.' f'Context around {actor.uid[0]}:{task} was already cancelled!'
'\nRe-opening a closed stream is not yet supported!'
'\nConsider re-calling the containing `@tractor.context` func'
) )
# XXX: If the underlying channel feeder receive mem chan has
# been closed then likely client code has already exited
# a ``.open_stream()`` block prior or there was some other
# unanticipated error or cancellation from ``trio``.
if recv_chan._closed:
raise trio.ClosedResourceError(
'The underlying channel for this stream was already closed!?')
async with MsgStream( async with MsgStream(
ctx=self, ctx=self,
rx_chan=recv_chan, rx_chan=recv_chan,