Compare commits

..

6 Commits

Author SHA1 Message Date
Tyler Goodlet 6120e99d7e Rename `._error` -> `._remote_ctx_error` 2023-01-30 14:13:43 -05:00
Tyler Goodlet 33b0e36ad6 Break loop after result retreival 2023-01-30 14:13:43 -05:00
Tyler Goodlet a71a958f54 Log context cancellation using `.cancel()` loglevel 2023-01-30 14:13:43 -05:00
Tyler Goodlet c9eb466d76 Use `MsgStream.subscribe()` in `Context.result()`
The case exists where there is multiple tasks consuming from an open
2-way stream created via `Context.open_stream()` where a sibling task is
pulling from the stream while some other task also calls `.result()`.
Previously the `.result()` call would consume (drain) stream messages
directly from the underlying mem chan which would mean any sibling task
would not receive those same messages. Instead, make `.result()` check
if a stream is open and instead consume (and discard) stream msgs using
a `BroadcastReceiver` (via `MsgStream.subscribe()`) such that all
interested tasks get copies of the same packets.
2023-01-30 14:13:42 -05:00
Tyler Goodlet f7a1f3832f Enable stream backpressure by default, add `MsgStream.ctx: Context` 2023-01-30 14:09:35 -05:00
Tyler Goodlet 3f2e33a120 Don't unset actor global on root teardown 2023-01-30 14:09:35 -05:00
3 changed files with 74 additions and 34 deletions

View File

@ -253,7 +253,6 @@ async def open_root_actor(
logger.cancel("Shutting down root actor") logger.cancel("Shutting down root actor")
await actor.cancel() await actor.cancel()
finally: finally:
_state._current_actor = None
logger.runtime("Root actor terminated") logger.runtime("Root actor terminated")

View File

@ -199,8 +199,8 @@ async def _invoke(
except BaseExceptionGroup: except BaseExceptionGroup:
# if a context error was set then likely # if a context error was set then likely
# thei multierror was raised due to that # thei multierror was raised due to that
if ctx._error is not None: if ctx._remote_ctx_error is not None:
raise ctx._error from None raise ctx._remote_ctx_error from None
raise raise

View File

@ -27,7 +27,8 @@ from typing import (
Optional, Optional,
Callable, Callable,
AsyncGenerator, AsyncGenerator,
AsyncIterator AsyncIterator,
TYPE_CHECKING,
) )
import warnings import warnings
@ -41,6 +42,10 @@ from .log import get_logger
from .trionics import broadcast_receiver, BroadcastReceiver from .trionics import broadcast_receiver, BroadcastReceiver
if TYPE_CHECKING:
from ._portal import Portal
log = get_logger(__name__) log = get_logger(__name__)
@ -70,7 +75,7 @@ class MsgStream(trio.abc.Channel):
''' '''
def __init__( def __init__(
self, self,
ctx: 'Context', # typing: ignore # noqa ctx: Context, # typing: ignore # noqa
rx_chan: trio.MemoryReceiveChannel, rx_chan: trio.MemoryReceiveChannel,
_broadcaster: Optional[BroadcastReceiver] = None, _broadcaster: Optional[BroadcastReceiver] = None,
@ -83,6 +88,9 @@ class MsgStream(trio.abc.Channel):
self._eoc: bool = False self._eoc: bool = False
self._closed: bool = False self._closed: bool = False
def ctx(self) -> Context:
return self._ctx
# delegate directly to underlying mem channel # delegate directly to underlying mem channel
def receive_nowait(self): def receive_nowait(self):
msg = self._rx_chan.receive_nowait() msg = self._rx_chan.receive_nowait()
@ -278,7 +286,6 @@ class MsgStream(trio.abc.Channel):
@asynccontextmanager @asynccontextmanager
async def subscribe( async def subscribe(
self, self,
) -> AsyncIterator[BroadcastReceiver]: ) -> AsyncIterator[BroadcastReceiver]:
''' '''
Allocate and return a ``BroadcastReceiver`` which delegates Allocate and return a ``BroadcastReceiver`` which delegates
@ -335,8 +342,8 @@ class MsgStream(trio.abc.Channel):
Send a message over this stream to the far end. Send a message over this stream to the far end.
''' '''
if self._ctx._error: if self._ctx._remote_ctx_error:
raise self._ctx._error # from None raise self._ctx._remote_ctx_error # from None
if self._closed: if self._closed:
raise trio.ClosedResourceError('This stream was already closed') raise trio.ClosedResourceError('This stream was already closed')
@ -375,9 +382,10 @@ class Context:
_remote_func_type: Optional[str] = None _remote_func_type: Optional[str] = None
# only set on the caller side # only set on the caller side
_portal: Optional['Portal'] = None # type: ignore # noqa _portal: Optional[Portal] = None # type: ignore # noqa
_stream: Optional[MsgStream] = None
_result: Optional[Any] = False _result: Optional[Any] = False
_error: Optional[BaseException] = None _remote_ctx_error: Optional[BaseException] = None
# status flags # status flags
_cancel_called: bool = False _cancel_called: bool = False
@ -390,7 +398,7 @@ class Context:
# only set on the callee side # only set on the callee side
_scope_nursery: Optional[trio.Nursery] = None _scope_nursery: Optional[trio.Nursery] = None
_backpressure: bool = False _backpressure: bool = True
async def send_yield(self, data: Any) -> None: async def send_yield(self, data: Any) -> None:
@ -435,21 +443,26 @@ class Context:
# (currently) that other portal APIs (``Portal.run()``, # (currently) that other portal APIs (``Portal.run()``,
# ``.run_in_actor()``) do their own error checking at the point # ``.run_in_actor()``) do their own error checking at the point
# of the call and result processing. # of the call and result processing.
log.error(
f'Remote context error for {self.chan.uid}:{self.cid}:\n'
f'{msg["error"]["tb_str"]}'
)
error = unpack_error(msg, self.chan) error = unpack_error(msg, self.chan)
if ( if (
isinstance(error, ContextCancelled) and isinstance(error, ContextCancelled)
self._cancel_called
): ):
# this is an expected cancel request response message log.cancel(
# and we don't need to raise it in scope since it will f'Remote context error for {self.chan.uid}:{self.cid}:\n'
# potentially override a real error f'{msg["error"]["tb_str"]}'
return )
if self._cancel_called:
# this is an expected cancel request response message
# and we don't need to raise it in scope since it will
# potentially override a real error
return
else:
log.error(
f'Remote context error for {self.chan.uid}:{self.cid}:\n'
f'{msg["error"]["tb_str"]}'
)
self._error = error self._remote_ctx_error = error
# TODO: tempted to **not** do this by-reraising in a # TODO: tempted to **not** do this by-reraising in a
# nursery and instead cancel a surrounding scope, detect # nursery and instead cancel a surrounding scope, detect
@ -457,7 +470,7 @@ class Context:
if self._scope_nursery: if self._scope_nursery:
async def raiser(): async def raiser():
raise self._error from None raise self._remote_ctx_error from None
# from trio.testing import wait_all_tasks_blocked # from trio.testing import wait_all_tasks_blocked
# await wait_all_tasks_blocked() # await wait_all_tasks_blocked()
@ -483,6 +496,7 @@ class Context:
log.cancel(f'Cancelling {side} side of context to {self.chan.uid}') log.cancel(f'Cancelling {side} side of context to {self.chan.uid}')
self._cancel_called = True self._cancel_called = True
ipc_broken: bool = False
if side == 'caller': if side == 'caller':
if not self._portal: if not self._portal:
@ -500,7 +514,14 @@ class Context:
# NOTE: we're telling the far end actor to cancel a task # NOTE: we're telling the far end actor to cancel a task
# corresponding to *this actor*. The far end local channel # corresponding to *this actor*. The far end local channel
# instance is passed to `Actor._cancel_task()` implicitly. # instance is passed to `Actor._cancel_task()` implicitly.
await self._portal.run_from_ns('self', '_cancel_task', cid=cid) try:
await self._portal.run_from_ns(
'self',
'_cancel_task',
cid=cid,
)
except trio.BrokenResourceError:
ipc_broken = True
if cs.cancelled_caught: if cs.cancelled_caught:
# XXX: there's no way to know if the remote task was indeed # XXX: there's no way to know if the remote task was indeed
@ -516,7 +537,10 @@ class Context:
"Timed out on cancelling remote task " "Timed out on cancelling remote task "
f"{cid} for {self._portal.channel.uid}") f"{cid} for {self._portal.channel.uid}")
# callee side remote task elif ipc_broken:
log.cancel(
"Transport layer was broken before cancel request "
f"{cid} for {self._portal.channel.uid}")
else: else:
self._cancel_msg = msg self._cancel_msg = msg
@ -604,6 +628,7 @@ class Context:
ctx=self, ctx=self,
rx_chan=ctx._recv_chan, rx_chan=ctx._recv_chan,
) as stream: ) as stream:
self._stream = stream
if self._portal: if self._portal:
self._portal._streams.add(stream) self._portal._streams.add(stream)
@ -645,25 +670,22 @@ class Context:
if not self._recv_chan._closed: # type: ignore if not self._recv_chan._closed: # type: ignore
# wait for a final context result consuming def consume(
# and discarding any bi dir stream msgs still msg: dict,
# in transit from the far end.
while True:
msg = await self._recv_chan.receive() ) -> Optional[dict]:
try: try:
self._result = msg['return'] return msg['return']
break
except KeyError as msgerr: except KeyError as msgerr:
if 'yield' in msg: if 'yield' in msg:
# far end task is still streaming to us so discard # far end task is still streaming to us so discard
log.warning(f'Discarding stream delivered {msg}') log.warning(f'Discarding stream delivered {msg}')
continue return
elif 'stop' in msg: elif 'stop' in msg:
log.debug('Remote stream terminated') log.debug('Remote stream terminated')
continue return
# internal error should never get here # internal error should never get here
assert msg.get('cid'), ( assert msg.get('cid'), (
@ -673,6 +695,25 @@ class Context:
msg, self._portal.channel msg, self._portal.channel
) from msgerr ) from msgerr
# wait for a final context result consuming
# and discarding any bi dir stream msgs still
# in transit from the far end.
if self._stream:
async with self._stream.subscribe() as bstream:
async for msg in bstream:
result = consume(msg)
if result:
self._result = result
break
if not self._result:
while True:
msg = await self._recv_chan.receive()
result = consume(msg)
if result:
self._result = result
break
return self._result return self._result
async def started( async def started(