forked from goodboy/tractor
1
0
Fork 0

WIP porting runtime to use `Msg`-spec

runtime_to_msgspec
Tyler Goodlet 2024-04-02 13:41:52 -04:00
parent f2ce4a3469
commit e153cc0187
10 changed files with 879 additions and 478 deletions

View File

@ -53,7 +53,14 @@ from ._exceptions import (
_raise_from_no_key_in_msg, _raise_from_no_key_in_msg,
) )
from .log import get_logger from .log import get_logger
from .msg import NamespacePath from .msg import (
NamespacePath,
Msg,
Return,
Started,
Stop,
Yield,
)
from ._ipc import Channel from ._ipc import Channel
from ._streaming import MsgStream from ._streaming import MsgStream
from ._state import ( from ._state import (
@ -96,7 +103,8 @@ async def _drain_to_final_msg(
# wait for a final context result by collecting (but # wait for a final context result by collecting (but
# basically ignoring) any bi-dir-stream msgs still in transit # basically ignoring) any bi-dir-stream msgs still in transit
# from the far end. # from the far end.
pre_result_drained: list[dict] = [] # pre_result_drained: list[dict] = []
pre_result_drained: list[Msg] = []
while not ( while not (
ctx.maybe_error ctx.maybe_error
and not ctx._final_result_is_set() and not ctx._final_result_is_set()
@ -155,7 +163,10 @@ async def _drain_to_final_msg(
# await pause() # await pause()
# pray to the `trio` gawds that we're corrent with this # pray to the `trio` gawds that we're corrent with this
msg: dict = await ctx._recv_chan.receive() # msg: dict = await ctx._recv_chan.receive()
msg: Msg = await ctx._recv_chan.receive()
# always capture unexpected/non-result msgs
pre_result_drained.append(msg)
# NOTE: we get here if the far end was # NOTE: we get here if the far end was
# `ContextCancelled` in 2 cases: # `ContextCancelled` in 2 cases:
@ -175,8 +186,15 @@ async def _drain_to_final_msg(
# continue to bubble up as normal. # continue to bubble up as normal.
raise raise
try: match msg:
ctx._result: Any = msg['return'] case Return(
cid=cid,
pld=res,
):
# try:
# ctx._result: Any = msg['return']
# ctx._result: Any = msg.pld
ctx._result: Any = res
log.runtime( log.runtime(
'Context delivered final draining msg:\n' 'Context delivered final draining msg:\n'
f'{pformat(msg)}' f'{pformat(msg)}'
@ -188,11 +206,11 @@ async def _drain_to_final_msg(
# TODO: ^ we don't need it right? # TODO: ^ we don't need it right?
break break
except KeyError: # except KeyError:
# always capture unexpected/non-result msgs # except AttributeError:
pre_result_drained.append(msg) case Yield():
# if 'yield' in msg:
if 'yield' in msg:
# far end task is still streaming to us so discard # far end task is still streaming to us so discard
# and report per local context state. # and report per local context state.
if ( if (
@ -238,9 +256,10 @@ async def _drain_to_final_msg(
# TODO: work out edge cases here where # TODO: work out edge cases here where
# a stream is open but the task also calls # a stream is open but the task also calls
# this? # this?
# -[ ] should be a runtime error if a stream is open # -[ ] should be a runtime error if a stream is open right?
# right? # Stop()
elif 'stop' in msg: case Stop():
# elif 'stop' in msg:
log.cancel( log.cancel(
'Remote stream terminated due to "stop" msg:\n\n' 'Remote stream terminated due to "stop" msg:\n\n'
f'{pformat(msg)}\n' f'{pformat(msg)}\n'
@ -249,7 +268,9 @@ async def _drain_to_final_msg(
# It's an internal error if any other msg type without # It's an internal error if any other msg type without
# a`'cid'` field arrives here! # a`'cid'` field arrives here!
if not msg.get('cid'): case _:
# if not msg.get('cid'):
if not msg.cid:
raise InternalError( raise InternalError(
'Unexpected cid-missing msg?\n\n' 'Unexpected cid-missing msg?\n\n'
f'{msg}\n' f'{msg}\n'
@ -710,10 +731,14 @@ class Context:
async def send_stop(self) -> None: async def send_stop(self) -> None:
# await pause() # await pause()
await self.chan.send({ # await self.chan.send({
'stop': True, # # Stop(
'cid': self.cid # 'stop': True,
}) # 'cid': self.cid
# })
await self.chan.send(
Stop(cid=self.cid)
)
def _maybe_cancel_and_set_remote_error( def _maybe_cancel_and_set_remote_error(
self, self,
@ -1395,17 +1420,19 @@ class Context:
for msg in drained_msgs: for msg in drained_msgs:
# TODO: mask this by default.. # TODO: mask this by default..
if 'return' in msg: # if 'return' in msg:
if isinstance(msg, Return):
# from .devx import pause # from .devx import pause
# await pause() # await pause()
raise InternalError( # raise InternalError(
log.warning(
'Final `return` msg should never be drained !?!?\n\n' 'Final `return` msg should never be drained !?!?\n\n'
f'{msg}\n' f'{msg}\n'
) )
log.cancel( log.cancel(
'Ctx drained pre-result msgs:\n' 'Ctx drained pre-result msgs:\n'
f'{drained_msgs}' f'{pformat(drained_msgs)}'
) )
self.maybe_raise( self.maybe_raise(
@ -1613,7 +1640,18 @@ class Context:
f'called `.started()` twice on context with {self.chan.uid}' f'called `.started()` twice on context with {self.chan.uid}'
) )
await self.chan.send({'started': value, 'cid': self.cid}) # await self.chan.send(
# {
# 'started': value,
# 'cid': self.cid,
# }
# )
await self.chan.send(
Started(
cid=self.cid,
pld=value,
)
)
self._started_called = True self._started_called = True
async def _drain_overflows( async def _drain_overflows(
@ -1668,7 +1706,8 @@ class Context:
async def _deliver_msg( async def _deliver_msg(
self, self,
msg: dict, # msg: dict,
msg: Msg,
) -> bool: ) -> bool:
''' '''
@ -1852,7 +1891,7 @@ class Context:
# anything different. # anything different.
return False return False
else: else:
txt += f'\n{msg}\n' # txt += f'\n{msg}\n'
# raise local overrun and immediately pack as IPC # raise local overrun and immediately pack as IPC
# msg for far end. # msg for far end.
try: try:
@ -1983,15 +2022,17 @@ async def open_context_from_portal(
) )
assert ctx._remote_func_type == 'context' assert ctx._remote_func_type == 'context'
msg: dict = await ctx._recv_chan.receive() msg: Started = await ctx._recv_chan.receive()
try: try:
# the "first" value here is delivered by the callee's # the "first" value here is delivered by the callee's
# ``Context.started()`` call. # ``Context.started()`` call.
first: Any = msg['started'] # first: Any = msg['started']
first: Any = msg.pld
ctx._started_called: bool = True ctx._started_called: bool = True
except KeyError as src_error: # except KeyError as src_error:
except AttributeError as src_error:
_raise_from_no_key_in_msg( _raise_from_no_key_in_msg(
ctx=ctx, ctx=ctx,
msg=msg, msg=msg,

View File

@ -135,6 +135,7 @@ def _trio_main(
run_as_asyncio_guest(trio_main) run_as_asyncio_guest(trio_main)
else: else:
trio.run(trio_main) trio.run(trio_main)
except KeyboardInterrupt: except KeyboardInterrupt:
log.cancel( log.cancel(
'Actor received KBI\n' 'Actor received KBI\n'

View File

@ -31,9 +31,16 @@ import textwrap
import traceback import traceback
import trio import trio
from msgspec import structs
from tractor._state import current_actor from tractor._state import current_actor
from tractor.log import get_logger from tractor.log import get_logger
from tractor.msg import (
Error,
Msg,
Stop,
Yield,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from ._context import Context from ._context import Context
@ -135,6 +142,8 @@ class RemoteActorError(Exception):
# and instead render if from `.boxed_type_str`? # and instead render if from `.boxed_type_str`?
self._boxed_type: BaseException = boxed_type self._boxed_type: BaseException = boxed_type
self._src_type: BaseException|None = None self._src_type: BaseException|None = None
# TODO: make this a `.errmsg: Error` throughout?
self.msgdata: dict[str, Any] = msgdata self.msgdata: dict[str, Any] = msgdata
# TODO: mask out eventually or place in `pack_error()` # TODO: mask out eventually or place in `pack_error()`
@ -464,7 +473,23 @@ class AsyncioCancelled(Exception):
''' '''
class MessagingError(Exception): class MessagingError(Exception):
'Some kind of unexpected SC messaging dialog issue' '''
IPC related msg (typing), transaction (ordering) or dialog
handling error.
'''
class MsgTypeError(MessagingError):
'''
Equivalent of a `TypeError` for an IPC wire-message
due to an invalid field value (type).
Normally this is re-raised from some `.msg._codec`
decode error raised by a backend interchange lib
like `msgspec` or `pycapnproto`.
'''
def pack_error( def pack_error(
@ -473,7 +498,7 @@ def pack_error(
tb: str|None = None, tb: str|None = None,
cid: str|None = None, cid: str|None = None,
) -> dict[str, dict]: ) -> Error|dict[str, dict]:
''' '''
Create an "error message" which boxes a locally caught Create an "error message" which boxes a locally caught
exception's meta-data and encodes it for wire transport via an exception's meta-data and encodes it for wire transport via an
@ -536,17 +561,23 @@ def pack_error(
# content's `.msgdata`). # content's `.msgdata`).
error_msg['tb_str'] = tb_str error_msg['tb_str'] = tb_str
pkt: dict = { # Error()
'error': error_msg, # pkt: dict = {
} # 'error': error_msg,
if cid: # }
pkt['cid'] = cid pkt: Error = Error(
cid=cid,
**error_msg,
# TODO: just get rid of `.pld` on this msg?
)
# if cid:
# pkt['cid'] = cid
return pkt return pkt
def unpack_error( def unpack_error(
msg: dict[str, Any], msg: dict[str, Any]|Error,
chan: Channel|None = None, chan: Channel|None = None,
box_type: RemoteActorError = RemoteActorError, box_type: RemoteActorError = RemoteActorError,
@ -564,15 +595,17 @@ def unpack_error(
''' '''
__tracebackhide__: bool = hide_tb __tracebackhide__: bool = hide_tb
error_dict: dict[str, dict] | None error_dict: dict[str, dict]|None
if ( if not isinstance(msg, Error):
error_dict := msg.get('error') # if (
) is None: # error_dict := msg.get('error')
# ) is None:
# no error field, nothing to unpack. # no error field, nothing to unpack.
return None return None
# retrieve the remote error's msg encoded details # retrieve the remote error's msg encoded details
tb_str: str = error_dict.get('tb_str', '') # tb_str: str = error_dict.get('tb_str', '')
tb_str: str = msg.tb_str
message: str = ( message: str = (
f'{chan.uid}\n' f'{chan.uid}\n'
+ +
@ -581,7 +614,8 @@ def unpack_error(
# try to lookup a suitable error type from the local runtime # try to lookup a suitable error type from the local runtime
# env then use it to construct a local instance. # env then use it to construct a local instance.
boxed_type_str: str = error_dict['boxed_type_str'] # boxed_type_str: str = error_dict['boxed_type_str']
boxed_type_str: str = msg.boxed_type_str
boxed_type: Type[BaseException] = get_err_type(boxed_type_str) boxed_type: Type[BaseException] = get_err_type(boxed_type_str)
if boxed_type_str == 'ContextCancelled': if boxed_type_str == 'ContextCancelled':
@ -595,7 +629,11 @@ def unpack_error(
# original source error. # original source error.
elif boxed_type_str == 'RemoteActorError': elif boxed_type_str == 'RemoteActorError':
assert boxed_type is RemoteActorError assert boxed_type is RemoteActorError
assert len(error_dict['relay_path']) >= 1 # assert len(error_dict['relay_path']) >= 1
assert len(msg.relay_path) >= 1
# TODO: mk RAE just take the `Error` instance directly?
error_dict: dict = structs.asdict(msg)
exc = box_type( exc = box_type(
message, message,
@ -623,11 +661,12 @@ def is_multi_cancelled(exc: BaseException) -> bool:
def _raise_from_no_key_in_msg( def _raise_from_no_key_in_msg(
ctx: Context, ctx: Context,
msg: dict, msg: Msg,
src_err: KeyError, src_err: KeyError,
log: StackLevelAdapter, # caller specific `log` obj log: StackLevelAdapter, # caller specific `log` obj
expect_key: str = 'yield', expect_key: str = 'yield',
expect_msg: str = Yield,
stream: MsgStream | None = None, stream: MsgStream | None = None,
# allow "deeper" tbs when debugging B^o # allow "deeper" tbs when debugging B^o
@ -660,8 +699,10 @@ def _raise_from_no_key_in_msg(
# an internal error should never get here # an internal error should never get here
try: try:
cid: str = msg['cid'] cid: str = msg.cid
except KeyError as src_err: # cid: str = msg['cid']
# except KeyError as src_err:
except AttributeError as src_err:
raise MessagingError( raise MessagingError(
f'IPC `Context` rx-ed msg without a ctx-id (cid)!?\n' f'IPC `Context` rx-ed msg without a ctx-id (cid)!?\n'
f'cid: {cid}\n\n' f'cid: {cid}\n\n'
@ -672,7 +713,10 @@ def _raise_from_no_key_in_msg(
# TODO: test that shows stream raising an expected error!!! # TODO: test that shows stream raising an expected error!!!
# raise the error message in a boxed exception type! # raise the error message in a boxed exception type!
if msg.get('error'): # if msg.get('error'):
if isinstance(msg, Error):
# match msg:
# case Error():
raise unpack_error( raise unpack_error(
msg, msg,
ctx.chan, ctx.chan,
@ -683,8 +727,10 @@ def _raise_from_no_key_in_msg(
# `MsgStream` termination msg. # `MsgStream` termination msg.
# TODO: does it make more sense to pack # TODO: does it make more sense to pack
# the stream._eoc outside this in the calleer always? # the stream._eoc outside this in the calleer always?
# case Stop():
elif ( elif (
msg.get('stop') # msg.get('stop')
isinstance(msg, Stop)
or ( or (
stream stream
and stream._eoc and stream._eoc
@ -725,14 +771,16 @@ def _raise_from_no_key_in_msg(
stream stream
and stream._closed and stream._closed
): ):
raise trio.ClosedResourceError('This stream was closed') # TODO: our own error subtype?
raise trio.ClosedResourceError(
'This stream was closed'
)
# always re-raise the source error if no translation error case # always re-raise the source error if no translation error case
# is activated above. # is activated above.
_type: str = 'Stream' if stream else 'Context' _type: str = 'Stream' if stream else 'Context'
raise MessagingError( raise MessagingError(
f"{_type} was expecting a '{expect_key}' message" f"{_type} was expecting a '{expect_key.upper()}' message"
" BUT received a non-error msg:\n" " BUT received a non-error msg:\n"
f'{pformat(msg)}' f'{pformat(msg)}'
) from src_err ) from src_err

View File

@ -38,17 +38,23 @@ from typing import (
Protocol, Protocol,
Type, Type,
TypeVar, TypeVar,
Union,
) )
import msgspec
from tricycle import BufferedReceiveStream from tricycle import BufferedReceiveStream
import trio import trio
from tractor.log import get_logger from tractor.log import get_logger
from tractor._exceptions import TransportClosed from tractor._exceptions import (
TransportClosed,
MsgTypeError,
)
from tractor.msg import ( from tractor.msg import (
_ctxvar_MsgCodec, _ctxvar_MsgCodec,
_codec,
MsgCodec, MsgCodec,
mk_codec, types,
) )
log = get_logger(__name__) log = get_logger(__name__)
@ -163,7 +169,16 @@ class MsgpackTCPStream(MsgTransport):
# allow for custom IPC msg interchange format # allow for custom IPC msg interchange format
# dynamic override Bo # dynamic override Bo
self.codec: MsgCodec = codec or mk_codec() self._task = trio.lowlevel.current_task()
self._codec: MsgCodec = (
codec
or
_codec._ctxvar_MsgCodec.get()
)
log.critical(
'!?!: USING STD `tractor` CODEC !?!?\n'
f'{self._codec}\n'
)
async def _iter_packets(self) -> AsyncGenerator[dict, None]: async def _iter_packets(self) -> AsyncGenerator[dict, None]:
''' '''
@ -171,7 +186,6 @@ class MsgpackTCPStream(MsgTransport):
stream using the current task's `MsgCodec`. stream using the current task's `MsgCodec`.
''' '''
import msgspec # noqa
decodes_failed: int = 0 decodes_failed: int = 0
while True: while True:
@ -206,7 +220,19 @@ class MsgpackTCPStream(MsgTransport):
try: try:
# NOTE: lookup the `trio.Task.context`'s var for # NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`. # the current `MsgCodec`.
yield _ctxvar_MsgCodec.get().decode(msg_bytes) codec: MsgCodec = _ctxvar_MsgCodec.get()
if self._codec.pld_spec != codec.pld_spec:
# assert (
# task := trio.lowlevel.current_task()
# ) is not self._task
# self._task = task
self._codec = codec
log.critical(
'.recv() USING NEW CODEC !?!?\n'
f'{self._codec}\n\n'
f'msg_bytes -> {msg_bytes}\n'
)
yield codec.decode(msg_bytes)
# TODO: remove, was only for orig draft impl # TODO: remove, was only for orig draft impl
# testing. # testing.
@ -221,6 +247,41 @@ class MsgpackTCPStream(MsgTransport):
# #
# yield obj # yield obj
# XXX NOTE: since the below error derives from
# `DecodeError` we need to catch is specially
# and always raise such that spec violations
# are never allowed to be caught silently!
except msgspec.ValidationError as verr:
# decode the msg-bytes using the std msgpack
# interchange-prot (i.e. without any
# `msgspec.Struct` handling) so that we can
# determine what `.msg.types.Msg` is the culprit
# by reporting the received value.
msg_dict: dict = msgspec.msgpack.decode(msg_bytes)
msg_type_name: str = msg_dict['msg_type']
msg_type = getattr(types, msg_type_name)
errmsg: str = (
f'Received invalid IPC `{msg_type_name}` msg\n\n'
)
# XXX see if we can determine the exact invalid field
# such that we can comprehensively report the
# specific field's type problem
msgspec_msg: str = verr.args[0].rstrip('`')
msg, _, maybe_field = msgspec_msg.rpartition('$.')
if field_val := msg_dict.get(maybe_field):
field_type: Union[Type] = msg_type.__signature__.parameters[
maybe_field
].annotation
errmsg += (
f'{msg.rstrip("`")}\n\n'
f'{msg_type}\n'
f' |_.{maybe_field}: {field_type} = {field_val}\n'
)
raise MsgTypeError(errmsg) from verr
except ( except (
msgspec.DecodeError, msgspec.DecodeError,
UnicodeDecodeError, UnicodeDecodeError,
@ -230,14 +291,15 @@ class MsgpackTCPStream(MsgTransport):
# do with a channel drop - hope that receiving from the # do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up. # channel will raise an expected error and bubble up.
try: try:
msg_str: str | bytes = msg_bytes.decode() msg_str: str|bytes = msg_bytes.decode()
except UnicodeDecodeError: except UnicodeDecodeError:
msg_str = msg_bytes msg_str = msg_bytes
log.error( log.exception(
'`msgspec` failed to decode!?\n' 'Failed to decode msg?\n'
'dumping bytes:\n' f'{codec}\n\n'
f'{msg_str!r}' 'Rxed bytes from wire:\n\n'
f'{msg_str!r}\n'
) )
decodes_failed += 1 decodes_failed += 1
else: else:
@ -258,8 +320,21 @@ class MsgpackTCPStream(MsgTransport):
# NOTE: lookup the `trio.Task.context`'s var for # NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`. # the current `MsgCodec`.
bytes_data: bytes = _ctxvar_MsgCodec.get().encode(msg) codec: MsgCodec = _ctxvar_MsgCodec.get()
# bytes_data: bytes = self.codec.encode(msg) # if self._codec != codec:
if self._codec.pld_spec != codec.pld_spec:
self._codec = codec
log.critical(
'.send() using NEW CODEC !?!?\n'
f'{self._codec}\n\n'
f'OBJ -> {msg}\n'
)
if type(msg) not in types.__spec__:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
bytes_data: bytes = codec.encode(msg)
# supposedly the fastest says, # supposedly the fastest says,
# https://stackoverflow.com/a/54027962 # https://stackoverflow.com/a/54027962

View File

@ -45,7 +45,10 @@ from ._state import (
) )
from ._ipc import Channel from ._ipc import Channel
from .log import get_logger from .log import get_logger
from .msg import NamespacePath from .msg import (
NamespacePath,
Return,
)
from ._exceptions import ( from ._exceptions import (
unpack_error, unpack_error,
NoResult, NoResult,
@ -66,7 +69,8 @@ log = get_logger(__name__)
# `._raise_from_no_key_in_msg()` (after tweak to # `._raise_from_no_key_in_msg()` (after tweak to
# accept a `chan: Channel` arg) in key block! # accept a `chan: Channel` arg) in key block!
def _unwrap_msg( def _unwrap_msg(
msg: dict[str, Any], # msg: dict[str, Any],
msg: Return,
channel: Channel, channel: Channel,
hide_tb: bool = True, hide_tb: bool = True,
@ -79,18 +83,21 @@ def _unwrap_msg(
__tracebackhide__: bool = hide_tb __tracebackhide__: bool = hide_tb
try: try:
return msg['return'] return msg.pld
except KeyError as ke: # return msg['return']
# except KeyError as ke:
except AttributeError as err:
# internal error should never get here # internal error should never get here
assert msg.get('cid'), ( # assert msg.get('cid'), (
assert msg.cid, (
"Received internal error at portal?" "Received internal error at portal?"
) )
raise unpack_error( raise unpack_error(
msg, msg,
channel channel
) from ke ) from err
class Portal: class Portal:

View File

@ -55,12 +55,21 @@ from ._exceptions import (
TransportClosed, TransportClosed,
) )
from .devx import ( from .devx import (
# pause, pause,
maybe_wait_for_debugger, maybe_wait_for_debugger,
_debug, _debug,
) )
from . import _state from . import _state
from .log import get_logger from .log import get_logger
from tractor.msg.types import (
Start,
StartAck,
Started,
Stop,
Yield,
Return,
Error,
)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -89,10 +98,13 @@ async def _invoke_non_context(
# TODO: can we unify this with the `context=True` impl below? # TODO: can we unify this with the `context=True` impl below?
if inspect.isasyncgen(coro): if inspect.isasyncgen(coro):
await chan.send({ # await chan.send({
'cid': cid, await chan.send(
'functype': 'asyncgen', StartAck(
}) cid=cid,
functype='asyncgen',
)
)
# XXX: massive gotcha! If the containing scope # XXX: massive gotcha! If the containing scope
# is cancelled and we execute the below line, # is cancelled and we execute the below line,
# any ``ActorNursery.__aexit__()`` WON'T be # any ``ActorNursery.__aexit__()`` WON'T be
@ -112,27 +124,45 @@ async def _invoke_non_context(
# to_send = await chan.recv_nowait() # to_send = await chan.recv_nowait()
# if to_send is not None: # if to_send is not None:
# to_yield = await coro.asend(to_send) # to_yield = await coro.asend(to_send)
await chan.send({ # await chan.send({
'yield': item, # # Yield()
'cid': cid, # 'cid': cid,
}) # 'yield': item,
# })
await chan.send(
Yield(
cid=cid,
pld=item,
)
)
log.runtime(f"Finished iterating {coro}") log.runtime(f"Finished iterating {coro}")
# TODO: we should really support a proper # TODO: we should really support a proper
# `StopAsyncIteration` system here for returning a final # `StopAsyncIteration` system here for returning a final
# value if desired # value if desired
await chan.send({ await chan.send(
'stop': True, Stop(cid=cid)
'cid': cid, )
}) # await chan.send({
# # Stop(
# 'cid': cid,
# 'stop': True,
# })
# one way @stream func that gets treated like an async gen # one way @stream func that gets treated like an async gen
# TODO: can we unify this with the `context=True` impl below? # TODO: can we unify this with the `context=True` impl below?
elif treat_as_gen: elif treat_as_gen:
await chan.send({ await chan.send(
'cid': cid, StartAck(
'functype': 'asyncgen', cid=cid,
}) functype='asyncgen',
)
)
# await chan.send({
# # StartAck()
# 'cid': cid,
# 'functype': 'asyncgen',
# })
# XXX: the async-func may spawn further tasks which push # XXX: the async-func may spawn further tasks which push
# back values like an async-generator would but must # back values like an async-generator would but must
# manualy construct the response dict-packet-responses as # manualy construct the response dict-packet-responses as
@ -145,10 +175,14 @@ async def _invoke_non_context(
if not cs.cancelled_caught: if not cs.cancelled_caught:
# task was not cancelled so we can instruct the # task was not cancelled so we can instruct the
# far end async gen to tear down # far end async gen to tear down
await chan.send({ await chan.send(
'stop': True, Stop(cid=cid)
'cid': cid )
}) # await chan.send({
# # Stop(
# 'cid': cid,
# 'stop': True,
# })
else: else:
# regular async function/method # regular async function/method
# XXX: possibly just a scheduled `Actor._cancel_task()` # XXX: possibly just a scheduled `Actor._cancel_task()`
@ -160,10 +194,17 @@ async def _invoke_non_context(
# way: using the linked IPC context machinery. # way: using the linked IPC context machinery.
failed_resp: bool = False failed_resp: bool = False
try: try:
await chan.send({ await chan.send(
'functype': 'asyncfunc', StartAck(
'cid': cid cid=cid,
}) functype='asyncfunc',
)
)
# await chan.send({
# # StartAck()
# 'cid': cid,
# 'functype': 'asyncfunc',
# })
except ( except (
trio.ClosedResourceError, trio.ClosedResourceError,
trio.BrokenResourceError, trio.BrokenResourceError,
@ -197,10 +238,17 @@ async def _invoke_non_context(
and chan.connected() and chan.connected()
): ):
try: try:
await chan.send({ # await chan.send({
'return': result, # # Return()
'cid': cid, # 'cid': cid,
}) # 'return': result,
# })
await chan.send(
Return(
cid=cid,
pld=result,
)
)
except ( except (
BrokenPipeError, BrokenPipeError,
trio.BrokenResourceError, trio.BrokenResourceError,
@ -381,6 +429,8 @@ async def _invoke(
# XXX for .pause_from_sync()` usage we need to make sure # XXX for .pause_from_sync()` usage we need to make sure
# `greenback` is boostrapped in the subactor! # `greenback` is boostrapped in the subactor!
await _debug.maybe_init_greenback() await _debug.maybe_init_greenback()
# else:
# await pause()
# TODO: possibly a specially formatted traceback # TODO: possibly a specially formatted traceback
# (not sure what typing is for this..)? # (not sure what typing is for this..)?
@ -493,10 +543,18 @@ async def _invoke(
# a "context" endpoint type is the most general and # a "context" endpoint type is the most general and
# "least sugary" type of RPC ep with support for # "least sugary" type of RPC ep with support for
# bi-dir streaming B) # bi-dir streaming B)
await chan.send({ # StartAck
'cid': cid, await chan.send(
'functype': 'context', StartAck(
}) cid=cid,
functype='context',
)
)
# await chan.send({
# # StartAck()
# 'cid': cid,
# 'functype': 'context',
# })
# TODO: should we also use an `.open_context()` equiv # TODO: should we also use an `.open_context()` equiv
# for this callee side by factoring the impl from # for this callee side by factoring the impl from
@ -520,10 +578,17 @@ async def _invoke(
ctx._result = res ctx._result = res
# deliver final result to caller side. # deliver final result to caller side.
await chan.send({ await chan.send(
'return': res, Return(
'cid': cid cid=cid,
}) pld=res,
)
)
# await chan.send({
# # Return()
# 'cid': cid,
# 'return': res,
# })
# NOTE: this happens IFF `ctx._scope.cancel()` is # NOTE: this happens IFF `ctx._scope.cancel()` is
# called by any of, # called by any of,
@ -696,7 +761,8 @@ async def try_ship_error_to_remote(
try: try:
# NOTE: normally only used for internal runtime errors # NOTE: normally only used for internal runtime errors
# so ship to peer actor without a cid. # so ship to peer actor without a cid.
msg: dict = pack_error( # msg: dict = pack_error(
msg: Error = pack_error(
err, err,
cid=cid, cid=cid,
@ -712,12 +778,13 @@ async def try_ship_error_to_remote(
trio.BrokenResourceError, trio.BrokenResourceError,
BrokenPipeError, BrokenPipeError,
): ):
err_msg: dict = msg['error']['tb_str'] # err_msg: dict = msg['error']['tb_str']
log.critical( log.critical(
'IPC transport failure -> ' 'IPC transport failure -> '
f'failed to ship error to {remote_descr}!\n\n' f'failed to ship error to {remote_descr}!\n\n'
f'X=> {channel.uid}\n\n' f'X=> {channel.uid}\n\n'
f'{err_msg}\n' # f'{err_msg}\n'
f'{msg}\n'
) )
@ -777,9 +844,20 @@ async def process_messages(
with CancelScope(shield=shield) as loop_cs: with CancelScope(shield=shield) as loop_cs:
task_status.started(loop_cs) task_status.started(loop_cs)
async for msg in chan: async for msg in chan:
log.transport( # type: ignore
f'<= IPC msg from peer: {chan.uid}\n\n'
# TODO: conditionally avoid fmting depending
# on log level (for perf)?
# => specifically `pformat()` sub-call..?
f'{pformat(msg)}\n'
)
match msg:
# if msg is None:
# dedicated loop terminate sentinel # dedicated loop terminate sentinel
if msg is None: case None:
tasks: dict[ tasks: dict[
tuple[Channel, str], tuple[Channel, str],
@ -802,17 +880,16 @@ async def process_messages(
) )
break break
log.transport( # type: ignore # cid = msg.get('cid')
f'<= IPC msg from peer: {chan.uid}\n\n' # if cid:
case (
# TODO: conditionally avoid fmting depending StartAck(cid=cid)
# on log level (for perf)? | Started(cid=cid)
# => specifically `pformat()` sub-call..? | Yield(cid=cid)
f'{pformat(msg)}\n' | Stop(cid=cid)
) | Return(cid=cid)
| Error(cid=cid)
cid = msg.get('cid') ):
if cid:
# deliver response to local caller/waiter # deliver response to local caller/waiter
# via its per-remote-context memory channel. # via its per-remote-context memory channel.
await actor._push_result( await actor._push_result(
@ -835,32 +912,44 @@ async def process_messages(
# -[ ] implement with ``match:`` syntax? # -[ ] implement with ``match:`` syntax?
# -[ ] discard un-authed msgs as per, # -[ ] discard un-authed msgs as per,
# <TODO put issue for typed msging structs> # <TODO put issue for typed msging structs>
try: case Start(
( cid=cid,
ns, ns=ns,
funcname, func=funcname,
kwargs, kwargs=kwargs,
actorid, uid=actorid,
cid, ):
) = msg['cmd'] # try:
# (
# ns,
# funcname,
# kwargs,
# actorid,
# cid,
# ) = msg['cmd']
except KeyError: # # TODO: put in `case Error():` right?
# This is the non-rpc error case, that is, an # except KeyError:
# error **not** raised inside a call to ``_invoke()`` # # This is the non-rpc error case, that is, an
# (i.e. no cid was provided in the msg - see above). # # error **not** raised inside a call to ``_invoke()``
# Push this error to all local channel consumers # # (i.e. no cid was provided in the msg - see above).
# (normally portals) by marking the channel as errored # # Push this error to all local channel consumers
assert chan.uid # # (normally portals) by marking the channel as errored
exc = unpack_error(msg, chan=chan) # assert chan.uid
chan._exc = exc # exc = unpack_error(msg, chan=chan)
raise exc # chan._exc = exc
# raise exc
log.runtime( log.runtime(
'Handling RPC cmd from\n' 'Handling RPC `Start` request from\n'
f'peer: {actorid}\n' f'peer: {actorid}\n'
'\n' '\n'
f'=> {ns}.{funcname}({kwargs})\n' f'=> {ns}.{funcname}({kwargs})\n'
) )
# case Start(
# ns='self',
# funcname='cancel',
# ):
if ns == 'self': if ns == 'self':
if funcname == 'cancel': if funcname == 'cancel':
func: Callable = actor.cancel func: Callable = actor.cancel
@ -896,6 +985,10 @@ async def process_messages(
loop_cs.cancel() loop_cs.cancel()
break break
# case Start(
# ns='self',
# funcname='_cancel_task',
# ):
if funcname == '_cancel_task': if funcname == '_cancel_task':
func: Callable = actor._cancel_task func: Callable = actor._cancel_task
@ -937,11 +1030,20 @@ async def process_messages(
f' |_cid: {target_cid}\n' f' |_cid: {target_cid}\n'
) )
continue continue
# case Start(
# ns='self',
# funcname='register_actor',
# ):
else: else:
# normally registry methods, eg. # normally registry methods, eg.
# ``.register_actor()`` etc. # ``.register_actor()`` etc.
func: Callable = getattr(actor, funcname) func: Callable = getattr(actor, funcname)
# case Start(
# ns=str(),
# funcname=funcname,
# ):
else: else:
# complain to client about restricted modules # complain to client about restricted modules
try: try:
@ -1022,6 +1124,24 @@ async def process_messages(
trio.Event(), trio.Event(),
) )
case Error()|_:
# This is the non-rpc error case, that is, an
# error **not** raised inside a call to ``_invoke()``
# (i.e. no cid was provided in the msg - see above).
# Push this error to all local channel consumers
# (normally portals) by marking the channel as errored
log.exception(
f'Unhandled IPC msg:\n\n'
f'{msg}\n'
)
assert chan.uid
exc = unpack_error(
msg,
chan=chan,
)
chan._exc = exc
raise exc
log.runtime( log.runtime(
'Waiting on next IPC msg from\n' 'Waiting on next IPC msg from\n'
f'peer: {chan.uid}\n' f'peer: {chan.uid}\n'

View File

@ -91,6 +91,23 @@ from ._rpc import (
process_messages, process_messages,
try_ship_error_to_remote, try_ship_error_to_remote,
) )
from tractor.msg import (
types as msgtypes,
pretty_struct,
)
# from tractor.msg.types import (
# Aid,
# SpawnSpec,
# Start,
# StartAck,
# Started,
# Yield,
# Stop,
# Return,
# Error,
# )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -147,6 +164,7 @@ class Actor:
# Information about `__main__` from parent # Information about `__main__` from parent
_parent_main_data: dict[str, str] _parent_main_data: dict[str, str]
_parent_chan_cs: CancelScope|None = None _parent_chan_cs: CancelScope|None = None
_spawn_spec: SpawnSpec|None = None
# syncs for setup/teardown sequences # syncs for setup/teardown sequences
_server_down: trio.Event|None = None _server_down: trio.Event|None = None
@ -537,7 +555,8 @@ class Actor:
f'{pformat(msg)}\n' f'{pformat(msg)}\n'
) )
cid = msg.get('cid') # cid: str|None = msg.get('cid')
cid: str|None = msg.cid
if cid: if cid:
# deliver response to local caller/waiter # deliver response to local caller/waiter
await self._push_result( await self._push_result(
@ -889,29 +908,44 @@ class Actor:
f'=> {ns}.{func}({kwargs})\n' f'=> {ns}.{func}({kwargs})\n'
) )
await chan.send( await chan.send(
{'cmd': ( msgtypes.Start(
ns, ns=ns,
func, func=func,
kwargs, kwargs=kwargs,
self.uid, uid=self.uid,
cid, cid=cid,
)}
) )
)
# {'cmd': (
# ns,
# func,
# kwargs,
# self.uid,
# cid,
# )}
# )
# Wait on first response msg and validate; this should be # Wait on first response msg and validate; this should be
# immediate. # immediate.
first_msg: dict = await ctx._recv_chan.receive() # first_msg: dict = await ctx._recv_chan.receive()
functype: str = first_msg.get('functype') # functype: str = first_msg.get('functype')
if 'error' in first_msg: first_msg: msgtypes.StartAck = await ctx._recv_chan.receive()
try:
functype: str = first_msg.functype
except AttributeError:
raise unpack_error(first_msg, chan) raise unpack_error(first_msg, chan)
# if 'error' in first_msg:
# raise unpack_error(first_msg, chan)
elif functype not in ( if functype not in (
'asyncfunc', 'asyncfunc',
'asyncgen', 'asyncgen',
'context', 'context',
): ):
raise ValueError(f"{first_msg} is an invalid response packet?") raise ValueError(
f'{first_msg} is an invalid response packet?'
)
ctx._remote_func_type = functype ctx._remote_func_type = functype
return ctx return ctx
@ -944,24 +978,36 @@ class Actor:
await self._do_handshake(chan) await self._do_handshake(chan)
accept_addrs: list[tuple[str, int]]|None = None accept_addrs: list[tuple[str, int]]|None = None
if self._spawn_method == "trio": if self._spawn_method == "trio":
# Receive runtime state from our parent # Receive runtime state from our parent
parent_data: dict[str, Any] # parent_data: dict[str, Any]
parent_data = await chan.recv() # parent_data = await chan.recv()
log.runtime(
'Received state from parent:\n\n' # TODO: maybe we should just wrap this directly
# in a `Actor.spawn_info: SpawnInfo` struct?
spawnspec: msgtypes.SpawnSpec = await chan.recv()
self._spawn_spec = spawnspec
# TODO: eventually all these msgs as # TODO: eventually all these msgs as
# `msgspec.Struct` with a special mode that # `msgspec.Struct` with a special mode that
# pformats them in multi-line mode, BUT only # pformats them in multi-line mode, BUT only
# if "trace"/"util" mode is enabled? # if "trace"/"util" mode is enabled?
f'{pformat(parent_data)}\n' log.runtime(
'Received runtime spec from parent:\n\n'
f'{pformat(spawnspec)}\n'
) )
accept_addrs: list[tuple[str, int]] = parent_data.pop('bind_addrs') # accept_addrs: list[tuple[str, int]] = parent_data.pop('bind_addrs')
rvs = parent_data.pop('_runtime_vars') accept_addrs: list[tuple[str, int]] = spawnspec.bind_addrs
# rvs = parent_data.pop('_runtime_vars')
rvs = spawnspec._runtime_vars
if rvs['_debug_mode']: if rvs['_debug_mode']:
try: try:
log.info('Enabling `stackscope` traces on SIGUSR1') log.info(
'Enabling `stackscope` traces on SIGUSR1'
)
from .devx import enable_stack_on_sig from .devx import enable_stack_on_sig
enable_stack_on_sig() enable_stack_on_sig()
except ImportError: except ImportError:
@ -969,28 +1015,40 @@ class Actor:
'`stackscope` not installed for use in debug mode!' '`stackscope` not installed for use in debug mode!'
) )
log.runtime(f"Runtime vars are: {rvs}") log.runtime(f'Runtime vars are: {rvs}')
rvs['_is_root'] = False rvs['_is_root'] = False
_state._runtime_vars.update(rvs) _state._runtime_vars.update(rvs)
for attr, value in parent_data.items():
if (
attr == 'reg_addrs'
and value
):
# XXX: ``msgspec`` doesn't support serializing tuples # XXX: ``msgspec`` doesn't support serializing tuples
# so just cash manually here since it's what our # so just cash manually here since it's what our
# internals expect. # internals expect.
# TODO: we don't really NEED these as #
# tuples so we can probably drop this self.reg_addrs = [
# casting since apparently in python lists # TODO: we don't really NEED these as tuples?
# are "more efficient"? # so we can probably drop this casting since
self.reg_addrs = [tuple(val) for val in value] # apparently in python lists are "more
# efficient"?
tuple(val)
for val in spawnspec.reg_addrs
]
else: # for attr, value in parent_data.items():
for _, attr, value in pretty_struct.iter_fields(
spawnspec,
):
setattr(self, attr, value) setattr(self, attr, value)
# if (
# attr == 'reg_addrs'
# and value
# ):
# self.reg_addrs = [tuple(val) for val in value]
# else:
# setattr(self, attr, value)
return chan, accept_addrs return (
chan,
accept_addrs,
)
except OSError: # failed to connect except OSError: # failed to connect
log.warning( log.warning(
@ -1432,7 +1490,7 @@ class Actor:
self, self,
chan: Channel chan: Channel
) -> tuple[str, str]: ) -> msgtypes.Aid:
''' '''
Exchange `(name, UUIDs)` identifiers as the first Exchange `(name, UUIDs)` identifiers as the first
communication step with any (peer) remote `Actor`. communication step with any (peer) remote `Actor`.
@ -1441,14 +1499,27 @@ class Actor:
"actor model" parlance. "actor model" parlance.
''' '''
await chan.send(self.uid) name, uuid = self.uid
value: tuple = await chan.recv() await chan.send(
uid: tuple[str, str] = (str(value[0]), str(value[1])) msgtypes.Aid(
name=name,
uuid=uuid,
)
)
aid: msgtypes.Aid = await chan.recv()
chan.aid = aid
uid: tuple[str, str] = (
# str(value[0]),
# str(value[1])
aid.name,
aid.uuid,
)
if not isinstance(uid, tuple): if not isinstance(uid, tuple):
raise ValueError(f"{uid} is not a valid uid?!") raise ValueError(f"{uid} is not a valid uid?!")
chan.uid = str(uid[0]), str(uid[1]) chan.uid = uid
return uid return uid
def is_infected_aio(self) -> bool: def is_infected_aio(self) -> bool:
@ -1508,7 +1579,8 @@ async def async_main(
# because we're running in mp mode # because we're running in mp mode
if ( if (
set_accept_addr_says_rent set_accept_addr_says_rent
and set_accept_addr_says_rent is not None and
set_accept_addr_says_rent is not None
): ):
accept_addrs = set_accept_addr_says_rent accept_addrs = set_accept_addr_says_rent

View File

@ -49,6 +49,9 @@ from tractor._portal import Portal
from tractor._runtime import Actor from tractor._runtime import Actor
from tractor._entry import _mp_main from tractor._entry import _mp_main
from tractor._exceptions import ActorFailure from tractor._exceptions import ActorFailure
from tractor.msg.types import (
SpawnSpec,
)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -489,14 +492,25 @@ async def trio_proc(
portal, portal,
) )
# send additional init params # send a "spawning specification" which configures the
await chan.send({ # initial runtime state of the child.
'_parent_main_data': subactor._parent_main_data, await chan.send(
'enable_modules': subactor.enable_modules, SpawnSpec(
'reg_addrs': subactor.reg_addrs, _parent_main_data=subactor._parent_main_data,
'bind_addrs': bind_addrs, enable_modules=subactor.enable_modules,
'_runtime_vars': _runtime_vars, reg_addrs=subactor.reg_addrs,
}) bind_addrs=bind_addrs,
_runtime_vars=_runtime_vars,
)
)
# await chan.send({
# '_parent_main_data': subactor._parent_main_data,
# 'enable_modules': subactor.enable_modules,
# 'reg_addrs': subactor.reg_addrs,
# 'bind_addrs': bind_addrs,
# '_runtime_vars': _runtime_vars,
# })
# track subactor in current nursery # track subactor in current nursery
curr_actor = current_actor() curr_actor = current_actor()

View File

@ -43,6 +43,11 @@ from .trionics import (
broadcast_receiver, broadcast_receiver,
BroadcastReceiver, BroadcastReceiver,
) )
from tractor.msg import (
Stop,
Yield,
Error,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from ._context import Context from ._context import Context
@ -94,21 +99,25 @@ class MsgStream(trio.abc.Channel):
self, self,
allow_msg_keys: list[str] = ['yield'], allow_msg_keys: list[str] = ['yield'],
): ):
msg: dict = self._rx_chan.receive_nowait() # msg: dict = self._rx_chan.receive_nowait()
msg: Yield|Stop = self._rx_chan.receive_nowait()
for ( for (
i, i,
key, key,
) in enumerate(allow_msg_keys): ) in enumerate(allow_msg_keys):
try: try:
return msg[key] # return msg[key]
except KeyError as kerr: return msg.pld
# except KeyError as kerr:
except AttributeError as attrerr:
if i < (len(allow_msg_keys) - 1): if i < (len(allow_msg_keys) - 1):
continue continue
_raise_from_no_key_in_msg( _raise_from_no_key_in_msg(
ctx=self._ctx, ctx=self._ctx,
msg=msg, msg=msg,
src_err=kerr, # src_err=kerr,
src_err=attrerr,
log=log, log=log,
expect_key=key, expect_key=key,
stream=self, stream=self,
@ -148,18 +157,22 @@ class MsgStream(trio.abc.Channel):
src_err: Exception|None = None # orig tb src_err: Exception|None = None # orig tb
try: try:
try: try:
msg = await self._rx_chan.receive() msg: Yield = await self._rx_chan.receive()
return msg['yield'] # return msg['yield']
return msg.pld
except KeyError as kerr: # except KeyError as kerr:
src_err = kerr except AttributeError as attrerr:
# src_err = kerr
src_err = attrerr
# NOTE: may raise any of the below error types # NOTE: may raise any of the below error types
# includg EoC when a 'stop' msg is found. # includg EoC when a 'stop' msg is found.
_raise_from_no_key_in_msg( _raise_from_no_key_in_msg(
ctx=self._ctx, ctx=self._ctx,
msg=msg, msg=msg,
src_err=kerr, # src_err=kerr,
src_err=attrerr,
log=log, log=log,
expect_key='yield', expect_key='yield',
stream=self, stream=self,
@ -514,11 +527,18 @@ class MsgStream(trio.abc.Channel):
raise self._closed raise self._closed
try: try:
# await self._ctx.chan.send(
# payload={
# 'yield': data,
# 'cid': self._ctx.cid,
# },
# # hide_tb=hide_tb,
# )
await self._ctx.chan.send( await self._ctx.chan.send(
payload={ payload=Yield(
'yield': data, cid=self._ctx.cid,
'cid': self._ctx.cid, pld=data,
}, ),
# hide_tb=hide_tb, # hide_tb=hide_tb,
) )
except ( except (

View File

@ -935,6 +935,9 @@ async def _pause(
# ``breakpoint()`` was awaited and begin handling stdio. # ``breakpoint()`` was awaited and begin handling stdio.
log.debug('Entering sync world of the `pdb` REPL..') log.debug('Entering sync world of the `pdb` REPL..')
try: try:
# log.critical(
# f'stack len: {len(pdb.stack)}\n'
# )
debug_func( debug_func(
actor, actor,
pdb, pdb,