Parameterize the `return_msg_type` in `._invoke()`

Since we also handle a runtime-specific `CancelAck`, allow the
caller-scheduler to pass in the expected return-type msg per the RPC msg
endpoint loop.
runtime_to_msgspec
Tyler Goodlet 2024-05-28 09:36:26 -04:00
parent 8b860f4245
commit 582144830f
1 changed files with 21 additions and 17 deletions

View File

@ -64,6 +64,7 @@ from .log import get_logger
from .msg import (
current_codec,
MsgCodec,
PayloadT,
NamespacePath,
pretty_struct,
)
@ -98,7 +99,7 @@ async def _invoke_non_context(
treat_as_gen: bool,
is_rpc: bool,
return_msg: Return|CancelAck = Return,
return_msg_type: Return|CancelAck = Return,
task_status: TaskStatus[
Context | BaseException
@ -220,7 +221,7 @@ async def _invoke_non_context(
and chan.connected()
):
try:
ret_msg = return_msg(
ret_msg = return_msg_type(
cid=cid,
pld=result,
)
@ -419,7 +420,7 @@ async def _invoke(
is_rpc: bool = True,
hide_tb: bool = True,
return_msg: Return|CancelAck = Return,
return_msg_type: Return|CancelAck = Return,
task_status: TaskStatus[
Context | BaseException
@ -533,7 +534,7 @@ async def _invoke(
kwargs,
treat_as_gen,
is_rpc,
return_msg,
return_msg_type,
task_status,
)
# XXX below fallthrough is ONLY for `@context` eps
@ -593,18 +594,21 @@ async def _invoke(
ctx._scope = tn.cancel_scope
task_status.started(ctx)
# TODO: should would be nice to have our
# `TaskMngr` nursery here!
res: Any = await coro
ctx._result = res
# deliver final result to caller side.
await chan.send(
return_msg(
cid=cid,
pld=res,
)
# TODO: better `trionics` tooling:
# -[ ] should would be nice to have our `TaskMngr`
# nursery here!
# -[ ] payload value checking like we do with
# `.started()` such that the debbuger can engage
# here in the child task instead of waiting for the
# parent to crash with it's own MTE..
res: Any|PayloadT = await coro
return_msg: Return|CancelAck = return_msg_type(
cid=cid,
pld=res,
)
# set and shuttle final result to "parent"-side task.
ctx._result = res
await chan.send(return_msg)
# NOTE: this happens IFF `ctx._scope.cancel()` is
# called by any of,
@ -940,7 +944,7 @@ async def process_messages(
actor.cancel,
kwargs,
is_rpc=False,
return_msg=CancelAck,
return_msg_type=CancelAck,
)
log.runtime(
@ -974,7 +978,7 @@ async def process_messages(
actor._cancel_task,
kwargs,
is_rpc=False,
return_msg=CancelAck,
return_msg_type=CancelAck,
)
except BaseException:
log.exception(