forked from goodboy/tractor
Compare commits
6 Commits
master
...
ctx_result
Author | SHA1 | Date |
---|---|---|
Tyler Goodlet | 6120e99d7e | |
Tyler Goodlet | 33b0e36ad6 | |
Tyler Goodlet | a71a958f54 | |
Tyler Goodlet | c9eb466d76 | |
Tyler Goodlet | f7a1f3832f | |
Tyler Goodlet | 3f2e33a120 |
|
@ -253,7 +253,6 @@ async def open_root_actor(
|
||||||
logger.cancel("Shutting down root actor")
|
logger.cancel("Shutting down root actor")
|
||||||
await actor.cancel()
|
await actor.cancel()
|
||||||
finally:
|
finally:
|
||||||
_state._current_actor = None
|
|
||||||
logger.runtime("Root actor terminated")
|
logger.runtime("Root actor terminated")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -199,8 +199,8 @@ async def _invoke(
|
||||||
except BaseExceptionGroup:
|
except BaseExceptionGroup:
|
||||||
# if a context error was set then likely
|
# if a context error was set then likely
|
||||||
# thei multierror was raised due to that
|
# thei multierror was raised due to that
|
||||||
if ctx._error is not None:
|
if ctx._remote_ctx_error is not None:
|
||||||
raise ctx._error from None
|
raise ctx._remote_ctx_error from None
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,8 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Callable,
|
Callable,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
AsyncIterator
|
AsyncIterator,
|
||||||
|
TYPE_CHECKING,
|
||||||
)
|
)
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -41,6 +42,10 @@ from .log import get_logger
|
||||||
from .trionics import broadcast_receiver, BroadcastReceiver
|
from .trionics import broadcast_receiver, BroadcastReceiver
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ._portal import Portal
|
||||||
|
|
||||||
|
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,7 +75,7 @@ class MsgStream(trio.abc.Channel):
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ctx: 'Context', # typing: ignore # noqa
|
ctx: Context, # typing: ignore # noqa
|
||||||
rx_chan: trio.MemoryReceiveChannel,
|
rx_chan: trio.MemoryReceiveChannel,
|
||||||
_broadcaster: Optional[BroadcastReceiver] = None,
|
_broadcaster: Optional[BroadcastReceiver] = None,
|
||||||
|
|
||||||
|
@ -83,6 +88,9 @@ class MsgStream(trio.abc.Channel):
|
||||||
self._eoc: bool = False
|
self._eoc: bool = False
|
||||||
self._closed: bool = False
|
self._closed: bool = False
|
||||||
|
|
||||||
|
def ctx(self) -> Context:
|
||||||
|
return self._ctx
|
||||||
|
|
||||||
# delegate directly to underlying mem channel
|
# delegate directly to underlying mem channel
|
||||||
def receive_nowait(self):
|
def receive_nowait(self):
|
||||||
msg = self._rx_chan.receive_nowait()
|
msg = self._rx_chan.receive_nowait()
|
||||||
|
@ -278,7 +286,6 @@ class MsgStream(trio.abc.Channel):
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def subscribe(
|
async def subscribe(
|
||||||
self,
|
self,
|
||||||
|
|
||||||
) -> AsyncIterator[BroadcastReceiver]:
|
) -> AsyncIterator[BroadcastReceiver]:
|
||||||
'''
|
'''
|
||||||
Allocate and return a ``BroadcastReceiver`` which delegates
|
Allocate and return a ``BroadcastReceiver`` which delegates
|
||||||
|
@ -335,8 +342,8 @@ class MsgStream(trio.abc.Channel):
|
||||||
Send a message over this stream to the far end.
|
Send a message over this stream to the far end.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
if self._ctx._error:
|
if self._ctx._remote_ctx_error:
|
||||||
raise self._ctx._error # from None
|
raise self._ctx._remote_ctx_error # from None
|
||||||
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise trio.ClosedResourceError('This stream was already closed')
|
raise trio.ClosedResourceError('This stream was already closed')
|
||||||
|
@ -375,9 +382,10 @@ class Context:
|
||||||
_remote_func_type: Optional[str] = None
|
_remote_func_type: Optional[str] = None
|
||||||
|
|
||||||
# only set on the caller side
|
# only set on the caller side
|
||||||
_portal: Optional['Portal'] = None # type: ignore # noqa
|
_portal: Optional[Portal] = None # type: ignore # noqa
|
||||||
|
_stream: Optional[MsgStream] = None
|
||||||
_result: Optional[Any] = False
|
_result: Optional[Any] = False
|
||||||
_error: Optional[BaseException] = None
|
_remote_ctx_error: Optional[BaseException] = None
|
||||||
|
|
||||||
# status flags
|
# status flags
|
||||||
_cancel_called: bool = False
|
_cancel_called: bool = False
|
||||||
|
@ -390,7 +398,7 @@ class Context:
|
||||||
# only set on the callee side
|
# only set on the callee side
|
||||||
_scope_nursery: Optional[trio.Nursery] = None
|
_scope_nursery: Optional[trio.Nursery] = None
|
||||||
|
|
||||||
_backpressure: bool = False
|
_backpressure: bool = True
|
||||||
|
|
||||||
async def send_yield(self, data: Any) -> None:
|
async def send_yield(self, data: Any) -> None:
|
||||||
|
|
||||||
|
@ -435,21 +443,26 @@ class Context:
|
||||||
# (currently) that other portal APIs (``Portal.run()``,
|
# (currently) that other portal APIs (``Portal.run()``,
|
||||||
# ``.run_in_actor()``) do their own error checking at the point
|
# ``.run_in_actor()``) do their own error checking at the point
|
||||||
# of the call and result processing.
|
# of the call and result processing.
|
||||||
log.error(
|
error = unpack_error(msg, self.chan)
|
||||||
|
if (
|
||||||
|
isinstance(error, ContextCancelled)
|
||||||
|
):
|
||||||
|
log.cancel(
|
||||||
f'Remote context error for {self.chan.uid}:{self.cid}:\n'
|
f'Remote context error for {self.chan.uid}:{self.cid}:\n'
|
||||||
f'{msg["error"]["tb_str"]}'
|
f'{msg["error"]["tb_str"]}'
|
||||||
)
|
)
|
||||||
error = unpack_error(msg, self.chan)
|
if self._cancel_called:
|
||||||
if (
|
|
||||||
isinstance(error, ContextCancelled) and
|
|
||||||
self._cancel_called
|
|
||||||
):
|
|
||||||
# this is an expected cancel request response message
|
# this is an expected cancel request response message
|
||||||
# and we don't need to raise it in scope since it will
|
# and we don't need to raise it in scope since it will
|
||||||
# potentially override a real error
|
# potentially override a real error
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
log.error(
|
||||||
|
f'Remote context error for {self.chan.uid}:{self.cid}:\n'
|
||||||
|
f'{msg["error"]["tb_str"]}'
|
||||||
|
)
|
||||||
|
|
||||||
self._error = error
|
self._remote_ctx_error = error
|
||||||
|
|
||||||
# TODO: tempted to **not** do this by-reraising in a
|
# TODO: tempted to **not** do this by-reraising in a
|
||||||
# nursery and instead cancel a surrounding scope, detect
|
# nursery and instead cancel a surrounding scope, detect
|
||||||
|
@ -457,7 +470,7 @@ class Context:
|
||||||
if self._scope_nursery:
|
if self._scope_nursery:
|
||||||
|
|
||||||
async def raiser():
|
async def raiser():
|
||||||
raise self._error from None
|
raise self._remote_ctx_error from None
|
||||||
|
|
||||||
# from trio.testing import wait_all_tasks_blocked
|
# from trio.testing import wait_all_tasks_blocked
|
||||||
# await wait_all_tasks_blocked()
|
# await wait_all_tasks_blocked()
|
||||||
|
@ -483,6 +496,7 @@ class Context:
|
||||||
log.cancel(f'Cancelling {side} side of context to {self.chan.uid}')
|
log.cancel(f'Cancelling {side} side of context to {self.chan.uid}')
|
||||||
|
|
||||||
self._cancel_called = True
|
self._cancel_called = True
|
||||||
|
ipc_broken: bool = False
|
||||||
|
|
||||||
if side == 'caller':
|
if side == 'caller':
|
||||||
if not self._portal:
|
if not self._portal:
|
||||||
|
@ -500,7 +514,14 @@ class Context:
|
||||||
# NOTE: we're telling the far end actor to cancel a task
|
# NOTE: we're telling the far end actor to cancel a task
|
||||||
# corresponding to *this actor*. The far end local channel
|
# corresponding to *this actor*. The far end local channel
|
||||||
# instance is passed to `Actor._cancel_task()` implicitly.
|
# instance is passed to `Actor._cancel_task()` implicitly.
|
||||||
await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
|
try:
|
||||||
|
await self._portal.run_from_ns(
|
||||||
|
'self',
|
||||||
|
'_cancel_task',
|
||||||
|
cid=cid,
|
||||||
|
)
|
||||||
|
except trio.BrokenResourceError:
|
||||||
|
ipc_broken = True
|
||||||
|
|
||||||
if cs.cancelled_caught:
|
if cs.cancelled_caught:
|
||||||
# XXX: there's no way to know if the remote task was indeed
|
# XXX: there's no way to know if the remote task was indeed
|
||||||
|
@ -516,7 +537,10 @@ class Context:
|
||||||
"Timed out on cancelling remote task "
|
"Timed out on cancelling remote task "
|
||||||
f"{cid} for {self._portal.channel.uid}")
|
f"{cid} for {self._portal.channel.uid}")
|
||||||
|
|
||||||
# callee side remote task
|
elif ipc_broken:
|
||||||
|
log.cancel(
|
||||||
|
"Transport layer was broken before cancel request "
|
||||||
|
f"{cid} for {self._portal.channel.uid}")
|
||||||
else:
|
else:
|
||||||
self._cancel_msg = msg
|
self._cancel_msg = msg
|
||||||
|
|
||||||
|
@ -604,6 +628,7 @@ class Context:
|
||||||
ctx=self,
|
ctx=self,
|
||||||
rx_chan=ctx._recv_chan,
|
rx_chan=ctx._recv_chan,
|
||||||
) as stream:
|
) as stream:
|
||||||
|
self._stream = stream
|
||||||
|
|
||||||
if self._portal:
|
if self._portal:
|
||||||
self._portal._streams.add(stream)
|
self._portal._streams.add(stream)
|
||||||
|
@ -645,25 +670,22 @@ class Context:
|
||||||
|
|
||||||
if not self._recv_chan._closed: # type: ignore
|
if not self._recv_chan._closed: # type: ignore
|
||||||
|
|
||||||
# wait for a final context result consuming
|
def consume(
|
||||||
# and discarding any bi dir stream msgs still
|
msg: dict,
|
||||||
# in transit from the far end.
|
|
||||||
while True:
|
|
||||||
|
|
||||||
msg = await self._recv_chan.receive()
|
) -> Optional[dict]:
|
||||||
try:
|
try:
|
||||||
self._result = msg['return']
|
return msg['return']
|
||||||
break
|
|
||||||
except KeyError as msgerr:
|
except KeyError as msgerr:
|
||||||
|
|
||||||
if 'yield' in msg:
|
if 'yield' in msg:
|
||||||
# far end task is still streaming to us so discard
|
# far end task is still streaming to us so discard
|
||||||
log.warning(f'Discarding stream delivered {msg}')
|
log.warning(f'Discarding stream delivered {msg}')
|
||||||
continue
|
return
|
||||||
|
|
||||||
elif 'stop' in msg:
|
elif 'stop' in msg:
|
||||||
log.debug('Remote stream terminated')
|
log.debug('Remote stream terminated')
|
||||||
continue
|
return
|
||||||
|
|
||||||
# internal error should never get here
|
# internal error should never get here
|
||||||
assert msg.get('cid'), (
|
assert msg.get('cid'), (
|
||||||
|
@ -673,6 +695,25 @@ class Context:
|
||||||
msg, self._portal.channel
|
msg, self._portal.channel
|
||||||
) from msgerr
|
) from msgerr
|
||||||
|
|
||||||
|
# wait for a final context result consuming
|
||||||
|
# and discarding any bi dir stream msgs still
|
||||||
|
# in transit from the far end.
|
||||||
|
if self._stream:
|
||||||
|
async with self._stream.subscribe() as bstream:
|
||||||
|
async for msg in bstream:
|
||||||
|
result = consume(msg)
|
||||||
|
if result:
|
||||||
|
self._result = result
|
||||||
|
break
|
||||||
|
|
||||||
|
if not self._result:
|
||||||
|
while True:
|
||||||
|
msg = await self._recv_chan.receive()
|
||||||
|
result = consume(msg)
|
||||||
|
if result:
|
||||||
|
self._result = result
|
||||||
|
break
|
||||||
|
|
||||||
return self._result
|
return self._result
|
||||||
|
|
||||||
async def started(
|
async def started(
|
||||||
|
|
Loading…
Reference in New Issue