forked from goodboy/tractor
Add initial bi-directional streaming
This mostly adds the api described in https://github.com/goodboy/tractor/issues/53#issuecomment-806258798 The first draft summary: - formalize bidir steaming using the `trio.Channel` style interface which we derive as a `MsgStream` type. - add `Portal.open_context()` which provides a `trio.Nursery.start()` remote task invocation style for setting up and tearing down tasks contexts in remote actors. - add a distinct `'started'` message to the ipc protocol to facilitate `Context.start()` with a first return value. - for our `ReceiveMsgStream` type, don't cancel the remote task in `.aclose()`; this is now done explicitly by the surrounding `Context` usage: `Context.cancel()`. - streams in either direction still use a `'yield'` message keeping the proto mostly symmetric without having to worry about which side is the caller / portal opener. - subtlety: only allow sending a `'stop'` message during a 2-way streaming context from `ReceiveStream.aclose()`, detailed comment with explanation is included. Relates to #53transport_hardening
parent
76f07898d9
commit
2870828c34
|
@ -14,6 +14,7 @@ from types import ModuleType
|
|||
import sys
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
import warnings
|
||||
|
||||
import trio # type: ignore
|
||||
from trio_typing import TaskStatus
|
||||
|
@ -58,13 +59,37 @@ async def _invoke(
|
|||
treat_as_gen = False
|
||||
cs = None
|
||||
cancel_scope = trio.CancelScope()
|
||||
ctx = Context(chan, cid, cancel_scope)
|
||||
ctx = Context(chan, cid, _cancel_scope=cancel_scope)
|
||||
context = False
|
||||
|
||||
if getattr(func, '_tractor_stream_function', False):
|
||||
# handle decorated ``@tractor.stream`` async functions
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
|
||||
# compat with old api
|
||||
kwargs['ctx'] = ctx
|
||||
|
||||
if 'ctx' in params:
|
||||
warnings.warn(
|
||||
"`@tractor.stream decorated funcs should now declare "
|
||||
"a `stream` arg, `ctx` is now designated for use with "
|
||||
"@tractor.context",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
elif 'stream' in params:
|
||||
assert 'stream' in params
|
||||
kwargs['stream'] = ctx
|
||||
|
||||
treat_as_gen = True
|
||||
|
||||
elif getattr(func, '_tractor_context_function', False):
|
||||
# handle decorated ``@tractor.context`` async function
|
||||
kwargs['ctx'] = ctx
|
||||
context = True
|
||||
|
||||
# errors raised inside this block are propgated back to caller
|
||||
try:
|
||||
if not (
|
||||
|
@ -102,26 +127,41 @@ async def _invoke(
|
|||
# `StopAsyncIteration` system here for returning a final
|
||||
# value if desired
|
||||
await chan.send({'stop': True, 'cid': cid})
|
||||
|
||||
# one way @stream func that gets treated like an async gen
|
||||
elif treat_as_gen:
|
||||
await chan.send({'functype': 'asyncgen', 'cid': cid})
|
||||
# XXX: the async-func may spawn further tasks which push
|
||||
# back values like an async-generator would but must
|
||||
# manualy construct the response dict-packet-responses as
|
||||
# above
|
||||
with cancel_scope as cs:
|
||||
task_status.started(cs)
|
||||
await coro
|
||||
|
||||
if not cs.cancelled_caught:
|
||||
# task was not cancelled so we can instruct the
|
||||
# far end async gen to tear down
|
||||
await chan.send({'stop': True, 'cid': cid})
|
||||
|
||||
elif context:
|
||||
# context func with support for bi-dir streaming
|
||||
await chan.send({'functype': 'context', 'cid': cid})
|
||||
|
||||
with cancel_scope as cs:
|
||||
task_status.started(cs)
|
||||
await chan.send({'return': await coro, 'cid': cid})
|
||||
|
||||
# if cs.cancelled_caught:
|
||||
# # task was cancelled so relay to the cancel to caller
|
||||
# await chan.send({'return': await coro, 'cid': cid})
|
||||
|
||||
else:
|
||||
if treat_as_gen:
|
||||
await chan.send({'functype': 'asyncgen', 'cid': cid})
|
||||
# XXX: the async-func may spawn further tasks which push
|
||||
# back values like an async-generator would but must
|
||||
# manualy construct the response dict-packet-responses as
|
||||
# above
|
||||
with cancel_scope as cs:
|
||||
task_status.started(cs)
|
||||
await coro
|
||||
if not cs.cancelled_caught:
|
||||
# task was not cancelled so we can instruct the
|
||||
# far end async gen to tear down
|
||||
await chan.send({'stop': True, 'cid': cid})
|
||||
else:
|
||||
# regular async function
|
||||
await chan.send({'functype': 'asyncfunc', 'cid': cid})
|
||||
with cancel_scope as cs:
|
||||
task_status.started(cs)
|
||||
await chan.send({'return': await coro, 'cid': cid})
|
||||
# regular async function
|
||||
await chan.send({'functype': 'asyncfunc', 'cid': cid})
|
||||
with cancel_scope as cs:
|
||||
task_status.started(cs)
|
||||
await chan.send({'return': await coro, 'cid': cid})
|
||||
|
||||
except (Exception, trio.MultiError) as err:
|
||||
|
||||
|
@ -417,10 +457,10 @@ class Actor:
|
|||
send_chan, recv_chan = self._cids2qs[(actorid, cid)]
|
||||
assert send_chan.cid == cid # type: ignore
|
||||
|
||||
if 'stop' in msg:
|
||||
log.debug(f"{send_chan} was terminated at remote end")
|
||||
# indicate to consumer that far end has stopped
|
||||
return await send_chan.aclose()
|
||||
# if 'stop' in msg:
|
||||
# log.debug(f"{send_chan} was terminated at remote end")
|
||||
# # indicate to consumer that far end has stopped
|
||||
# return await send_chan.aclose()
|
||||
|
||||
try:
|
||||
log.debug(f"Delivering {msg} from {actorid} to caller {cid}")
|
||||
|
@ -428,6 +468,12 @@ class Actor:
|
|||
await send_chan.send(msg)
|
||||
|
||||
except trio.BrokenResourceError:
|
||||
# TODO: what is the right way to handle the case where the
|
||||
# local task has already sent a 'stop' / StopAsyncInteration
|
||||
# to the other side but and possibly has closed the local
|
||||
# feeder mem chan? Do we wait for some kind of ack or just
|
||||
# let this fail silently and bubble up (currently)?
|
||||
|
||||
# XXX: local consumer has closed their side
|
||||
# so cancel the far end streaming task
|
||||
log.warning(f"{send_chan} consumer is already closed")
|
||||
|
@ -508,6 +554,7 @@ class Actor:
|
|||
if cid:
|
||||
# deliver response to local caller/waiter
|
||||
await self._push_result(chan, cid, msg)
|
||||
|
||||
log.debug(
|
||||
f"Waiting on next msg for {chan} from {chan.uid}")
|
||||
continue
|
||||
|
|
|
@ -312,11 +312,20 @@ class Portal:
|
|||
|
||||
ctx = Context(self.channel, cid, _portal=self)
|
||||
try:
|
||||
async with ReceiveMsgStream(ctx, recv_chan, self) as rchan:
|
||||
# deliver receive only stream
|
||||
async with ReceiveMsgStream(ctx, recv_chan) as rchan:
|
||||
self._streams.add(rchan)
|
||||
yield rchan
|
||||
|
||||
finally:
|
||||
|
||||
# cancel the far end task on consumer close
|
||||
# NOTE: this is a special case since we assume that if using
|
||||
# this ``.open_fream_from()`` api, the stream is one a one
|
||||
# time use and we couple the far end tasks's lifetime to
|
||||
# the consumer's scope; we don't ever send a `'stop'`
|
||||
# message right now since there shouldn't be a reason to
|
||||
# stop and restart the stream, right?
|
||||
try:
|
||||
await ctx.cancel()
|
||||
except trio.ClosedResourceError:
|
||||
|
@ -326,16 +335,55 @@ class Portal:
|
|||
|
||||
self._streams.remove(rchan)
|
||||
|
||||
# @asynccontextmanager
|
||||
# async def open_context(
|
||||
# self,
|
||||
# func: Callable,
|
||||
# **kwargs,
|
||||
# ) -> Context:
|
||||
# # TODO
|
||||
# elif resptype == 'context': # context manager style setup/teardown
|
||||
# # TODO likely not here though
|
||||
# raise NotImplementedError
|
||||
@asynccontextmanager
|
||||
async def open_context(
|
||||
self,
|
||||
func: Callable,
|
||||
**kwargs,
|
||||
) -> Context:
|
||||
"""Open an inter-actor task context.
|
||||
|
||||
This is a synchronous API which allows for deterministic
|
||||
setup/teardown of a remote task. The yielded ``Context`` further
|
||||
allows for opening bidirectional streams - see
|
||||
``Context.open_stream()``.
|
||||
|
||||
"""
|
||||
# conduct target func method structural checks
|
||||
if not inspect.iscoroutinefunction(func) and (
|
||||
getattr(func, '_tractor_contex_function', False)
|
||||
):
|
||||
raise TypeError(
|
||||
f'{func} must be an async generator function!')
|
||||
|
||||
fn_mod_path, fn_name = func_deats(func)
|
||||
|
||||
cid, recv_chan, functype, first_msg = await self._submit(
|
||||
fn_mod_path, fn_name, kwargs)
|
||||
|
||||
assert functype == 'context'
|
||||
|
||||
msg = await recv_chan.receive()
|
||||
try:
|
||||
# the "first" value here is delivered by the callee's
|
||||
# ``Context.started()`` call.
|
||||
first = msg['started']
|
||||
|
||||
except KeyError:
|
||||
assert msg.get('cid'), ("Received internal error at context?")
|
||||
|
||||
if msg.get('error'):
|
||||
# raise the error message
|
||||
raise unpack_error(msg, self.channel)
|
||||
else:
|
||||
raise
|
||||
try:
|
||||
ctx = Context(self.channel, cid, _portal=self)
|
||||
yield ctx, first
|
||||
|
||||
finally:
|
||||
await recv_chan.aclose()
|
||||
await ctx.cancel()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -1,19 +1,195 @@
|
|||
import inspect
|
||||
from contextlib import contextmanager # , asynccontextmanager
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterator, Optional
|
||||
from typing import Any, Iterator, Optional, Callable
|
||||
import warnings
|
||||
|
||||
import trio
|
||||
|
||||
from ._ipc import Channel
|
||||
from ._exceptions import unpack_error
|
||||
from ._state import current_actor
|
||||
from .log import get_logger
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: generic typing like trio's receive channel
|
||||
# but with msgspec messages?
|
||||
# class ReceiveChannel(AsyncResource, Generic[ReceiveType]):
|
||||
|
||||
|
||||
class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||
"""A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
|
||||
special behaviour for signalling stream termination across an
|
||||
inter-actor ``Channel``. This is the type returned to a local task
|
||||
which invoked a remote streaming function using `Portal.run()`.
|
||||
|
||||
Termination rules:
|
||||
- if the local task signals stop iteration a cancel signal is
|
||||
relayed to the remote task indicating to stop streaming
|
||||
- if the remote task signals the end of a stream, raise a
|
||||
``StopAsyncIteration`` to terminate the local ``async for``
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
ctx: 'Context', # typing: ignore # noqa
|
||||
rx_chan: trio.abc.ReceiveChannel,
|
||||
) -> None:
|
||||
self._ctx = ctx
|
||||
self._rx_chan = rx_chan
|
||||
self._shielded = False
|
||||
|
||||
# delegate directly to underlying mem channel
|
||||
def receive_nowait(self):
|
||||
return self._rx_chan.receive_nowait()
|
||||
|
||||
async def receive(self):
|
||||
try:
|
||||
msg = await self._rx_chan.receive()
|
||||
return msg['yield']
|
||||
|
||||
except KeyError:
|
||||
# internal error should never get here
|
||||
assert msg.get('cid'), ("Received internal error at portal?")
|
||||
|
||||
# TODO: handle 2 cases with 3.10 match syntax
|
||||
# - 'stop'
|
||||
# - 'error'
|
||||
# possibly just handle msg['stop'] here!
|
||||
|
||||
if msg.get('stop'):
|
||||
log.debug(f"{self} was stopped at remote end")
|
||||
# when the send is closed we assume the stream has
|
||||
# terminated and signal this local iterator to stop
|
||||
await self.aclose()
|
||||
raise trio.EndOfChannel
|
||||
|
||||
# TODO: test that shows stream raising an expected error!!!
|
||||
elif msg.get('error'):
|
||||
# raise the error message
|
||||
raise unpack_error(msg, self._ctx.chan)
|
||||
|
||||
else:
|
||||
raise
|
||||
|
||||
except (trio.ClosedResourceError, StopAsyncIteration):
|
||||
# XXX: this indicates that a `stop` message was
|
||||
# sent by the far side of the underlying channel.
|
||||
# Currently this is triggered by calling ``.aclose()`` on
|
||||
# the send side of the channel inside
|
||||
# ``Actor._push_result()``, but maybe it should be put here?
|
||||
# to avoid exposing the internal mem chan closing mechanism?
|
||||
# in theory we could instead do some flushing of the channel
|
||||
# if needed to ensure all consumers are complete before
|
||||
# triggering closure too early?
|
||||
|
||||
# Locally, we want to close this stream gracefully, by
|
||||
# terminating any local consumers tasks deterministically.
|
||||
# We **don't** want to be closing this send channel and not
|
||||
# relaying a final value to remaining consumers who may not
|
||||
# have been scheduled to receive it yet?
|
||||
|
||||
# lots of testing to do here
|
||||
|
||||
# when the send is closed we assume the stream has
|
||||
# terminated and signal this local iterator to stop
|
||||
await self.aclose()
|
||||
# await self._ctx.send_stop()
|
||||
raise StopAsyncIteration
|
||||
|
||||
except trio.Cancelled:
|
||||
# relay cancels to the remote task
|
||||
await self.aclose()
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def shield(
|
||||
self
|
||||
) -> Iterator['ReceiveMsgStream']: # noqa
|
||||
"""Shield this stream's underlying channel such that a local consumer task
|
||||
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
||||
|
||||
"""
|
||||
self._shielded = True
|
||||
yield self
|
||||
self._shielded = False
|
||||
|
||||
async def aclose(self):
|
||||
"""Cancel associated remote actor task and local memory channel
|
||||
on close.
|
||||
|
||||
"""
|
||||
# TODO: proper adherance to trio's `.aclose()` semantics:
|
||||
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
|
||||
rx_chan = self._rx_chan
|
||||
|
||||
if rx_chan._closed:
|
||||
log.warning(f"{self} is already closed")
|
||||
return
|
||||
|
||||
# TODO: broadcasting to multiple consumers
|
||||
# stats = rx_chan.statistics()
|
||||
# if stats.open_receive_channels > 1:
|
||||
# # if we've been cloned don't kill the stream
|
||||
# log.debug(
|
||||
# "there are still consumers running keeping stream alive")
|
||||
# return
|
||||
|
||||
if self._shielded:
|
||||
log.warning(f"{self} is shielded, portal channel being kept alive")
|
||||
return
|
||||
|
||||
# NOTE: this is super subtle IPC messaging stuff:
|
||||
# Relay stop iteration to far end **iff** we're
|
||||
# in bidirectional mode. If we're only streaming
|
||||
# *from* one side then that side **won't** have an
|
||||
# entry in `Actor._cids2qs` (maybe it should though?).
|
||||
# So any `yield` or `stop` msgs sent from the caller side
|
||||
# will cause key errors on the callee side since there is
|
||||
# no entry for a local feeder mem chan since the callee task
|
||||
# isn't expecting messages to be sent by the caller.
|
||||
# Thus, we must check that this context DOES NOT
|
||||
# have a portal reference to ensure this is indeed the callee
|
||||
# side and can relay a 'stop'. In the bidirectional case,
|
||||
# `Context.open_stream()` will create the `Actor._cids2qs`
|
||||
# entry from a call to `Actor.get_memchans()`.
|
||||
if not self._ctx._portal:
|
||||
# only for 2 way streams can we can send
|
||||
# stop from the caller side
|
||||
await self._ctx.send_stop()
|
||||
|
||||
# close the local mem chan
|
||||
rx_chan.close()
|
||||
|
||||
# TODO: but make it broadcasting to consumers
|
||||
# def clone(self):
|
||||
# """Clone this receive channel allowing for multi-task
|
||||
# consumption from the same channel.
|
||||
|
||||
# """
|
||||
# return ReceiveStream(
|
||||
# self._cid,
|
||||
# self._rx_chan.clone(),
|
||||
# self._portal,
|
||||
# )
|
||||
|
||||
|
||||
class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
||||
"""
|
||||
Bidirectional message stream for use within an inter-actor actor
|
||||
``Context```.
|
||||
|
||||
"""
|
||||
async def send(
|
||||
self,
|
||||
data: Any
|
||||
) -> None:
|
||||
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Context:
|
||||
"""An IAC (inter-actor communication) context.
|
||||
|
@ -31,6 +207,10 @@ class Context:
|
|||
chan: Channel
|
||||
cid: str
|
||||
|
||||
# TODO: should we have seperate types for caller vs. callee
|
||||
# side contexts? The caller always opens a portal whereas the callee
|
||||
# is always responding back through a context-stream
|
||||
|
||||
# only set on the caller side
|
||||
_portal: Optional['Portal'] = None # type: ignore # noqa
|
||||
|
||||
|
@ -57,46 +237,97 @@ class Context:
|
|||
timeout quickly to sidestep 2-generals...
|
||||
|
||||
"""
|
||||
assert self._portal, (
|
||||
"No portal found, this is likely a callee side context")
|
||||
if self._portal: # caller side:
|
||||
if not self._portal:
|
||||
raise RuntimeError(
|
||||
"No portal found, this is likely a callee side context"
|
||||
)
|
||||
|
||||
cid = self.cid
|
||||
with trio.move_on_after(0.5) as cs:
|
||||
cs.shield = True
|
||||
log.warning(
|
||||
f"Cancelling stream {cid} to "
|
||||
f"{self._portal.channel.uid}")
|
||||
|
||||
# NOTE: we're telling the far end actor to cancel a task
|
||||
# corresponding to *this actor*. The far end local channel
|
||||
# instance is passed to `Actor._cancel_task()` implicitly.
|
||||
await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
|
||||
|
||||
if cs.cancelled_caught:
|
||||
# XXX: there's no way to know if the remote task was indeed
|
||||
# cancelled in the case where the connection is broken or
|
||||
# some other network error occurred.
|
||||
if not self._portal.channel.connected():
|
||||
cid = self.cid
|
||||
with trio.move_on_after(0.5) as cs:
|
||||
cs.shield = True
|
||||
log.warning(
|
||||
"May have failed to cancel remote task "
|
||||
f"{cid} for {self._portal.channel.uid}")
|
||||
f"Cancelling stream {cid} to "
|
||||
f"{self._portal.channel.uid}")
|
||||
|
||||
# NOTE: we're telling the far end actor to cancel a task
|
||||
# corresponding to *this actor*. The far end local channel
|
||||
# instance is passed to `Actor._cancel_task()` implicitly.
|
||||
await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
|
||||
|
||||
if cs.cancelled_caught:
|
||||
# XXX: there's no way to know if the remote task was indeed
|
||||
# cancelled in the case where the connection is broken or
|
||||
# some other network error occurred.
|
||||
# if not self._portal.channel.connected():
|
||||
if not self.chan.connected():
|
||||
log.warning(
|
||||
"May have failed to cancel remote task "
|
||||
f"{cid} for {self._portal.channel.uid}")
|
||||
else:
|
||||
# ensure callee side
|
||||
assert self._cancel_scope
|
||||
# TODO: should we have an explicit cancel message
|
||||
# or is relaying the local `trio.Cancelled` as an
|
||||
# {'error': trio.Cancelled, cid: "blah"} enough?
|
||||
# This probably gets into the discussion in
|
||||
# https://github.com/goodboy/tractor/issues/36
|
||||
self._cancel_scope.cancel()
|
||||
|
||||
# TODO: do we need a restart api?
|
||||
# async def restart(self) -> None:
|
||||
# # TODO
|
||||
# pass
|
||||
|
||||
# @asynccontextmanager
|
||||
# async def open_stream(
|
||||
# self,
|
||||
# ) -> AsyncContextManager:
|
||||
# # TODO
|
||||
# pass
|
||||
@asynccontextmanager
|
||||
async def open_stream(
|
||||
self,
|
||||
) -> MsgStream:
|
||||
# TODO
|
||||
|
||||
actor = current_actor()
|
||||
|
||||
# here we create a mem chan that corresponds to the
|
||||
# far end caller / callee.
|
||||
|
||||
# NOTE: in one way streaming this only happens on the
|
||||
# caller side inside `Actor.send_cmd()` so if you try
|
||||
# to send a stop from the caller to the callee in the
|
||||
# single-direction-stream case you'll get a lookup error
|
||||
# currently.
|
||||
_, recv_chan = actor.get_memchans(
|
||||
self.chan.uid,
|
||||
self.cid
|
||||
)
|
||||
|
||||
async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan:
|
||||
|
||||
if self._portal:
|
||||
self._portal._streams.add(rchan)
|
||||
|
||||
try:
|
||||
yield rchan
|
||||
|
||||
finally:
|
||||
await self.send_stop()
|
||||
if self._portal:
|
||||
self._portal._streams.add(rchan)
|
||||
|
||||
async def started(self, value: Any) -> None:
|
||||
|
||||
if self._portal:
|
||||
raise RuntimeError(
|
||||
f"Caller side context {self} can not call started!")
|
||||
|
||||
await self.chan.send({'started': value, 'cid': self.cid})
|
||||
|
||||
|
||||
def stream(func):
|
||||
def stream(func: Callable) -> Callable:
|
||||
"""Mark an async function as a streaming routine with ``@stream``.
|
||||
|
||||
"""
|
||||
# annotate
|
||||
func._tractor_stream_function = True
|
||||
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
if 'stream' not in params and 'ctx' in params:
|
||||
|
@ -114,147 +345,24 @@ def stream(func):
|
|||
):
|
||||
raise TypeError(
|
||||
"The first argument to the stream function "
|
||||
f"{func.__name__} must be `ctx: tractor.Context`"
|
||||
f"{func.__name__} must be `ctx: tractor.Context` "
|
||||
"(Or ``to_trio`` if using ``asyncio`` in guest mode)."
|
||||
)
|
||||
return func
|
||||
|
||||
|
||||
class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||
"""A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
|
||||
special behaviour for signalling stream termination across an
|
||||
inter-actor ``Channel``. This is the type returned to a local task
|
||||
which invoked a remote streaming function using `Portal.run()`.
|
||||
|
||||
Termination rules:
|
||||
- if the local task signals stop iteration a cancel signal is
|
||||
relayed to the remote task indicating to stop streaming
|
||||
- if the remote task signals the end of a stream, raise a
|
||||
``StopAsyncIteration`` to terminate the local ``async for``
|
||||
def context(func: Callable) -> Callable:
|
||||
"""Mark an async function as a streaming routine with ``@context``.
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
ctx: Context,
|
||||
rx_chan: trio.abc.ReceiveChannel,
|
||||
portal: 'Portal', # type: ignore # noqa
|
||||
) -> None:
|
||||
self._ctx = ctx
|
||||
self._rx_chan = rx_chan
|
||||
self._portal = portal
|
||||
self._shielded = False
|
||||
# annotate
|
||||
func._tractor_context_function = True
|
||||
|
||||
# delegate directly to underlying mem channel
|
||||
def receive_nowait(self):
|
||||
return self._rx_chan.receive_nowait()
|
||||
|
||||
async def receive(self):
|
||||
try:
|
||||
msg = await self._rx_chan.receive()
|
||||
return msg['yield']
|
||||
|
||||
except KeyError:
|
||||
# internal error should never get here
|
||||
assert msg.get('cid'), ("Received internal error at portal?")
|
||||
|
||||
# TODO: handle 2 cases with 3.10 match syntax
|
||||
# - 'stop'
|
||||
# - 'error'
|
||||
# possibly just handle msg['stop'] here!
|
||||
|
||||
# TODO: test that shows stream raising an expected error!!!
|
||||
if msg.get('error'):
|
||||
# raise the error message
|
||||
raise unpack_error(msg, self._portal.channel)
|
||||
|
||||
except (trio.ClosedResourceError, StopAsyncIteration):
|
||||
# XXX: this indicates that a `stop` message was
|
||||
# sent by the far side of the underlying channel.
|
||||
# Currently this is triggered by calling ``.aclose()`` on
|
||||
# the send side of the channel inside
|
||||
# ``Actor._push_result()``, but maybe it should be put here?
|
||||
# to avoid exposing the internal mem chan closing mechanism?
|
||||
# in theory we could instead do some flushing of the channel
|
||||
# if needed to ensure all consumers are complete before
|
||||
# triggering closure too early?
|
||||
|
||||
# Locally, we want to close this stream gracefully, by
|
||||
# terminating any local consumers tasks deterministically.
|
||||
# We **don't** want to be closing this send channel and not
|
||||
# relaying a final value to remaining consumers who may not
|
||||
# have been scheduled to receive it yet?
|
||||
|
||||
# lots of testing to do here
|
||||
|
||||
# when the send is closed we assume the stream has
|
||||
# terminated and signal this local iterator to stop
|
||||
await self.aclose()
|
||||
raise StopAsyncIteration
|
||||
|
||||
except trio.Cancelled:
|
||||
# relay cancels to the remote task
|
||||
await self.aclose()
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def shield(
|
||||
self
|
||||
) -> Iterator['ReceiveMsgStream']: # noqa
|
||||
"""Shield this stream's underlying channel such that a local consumer task
|
||||
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
||||
|
||||
"""
|
||||
self._shielded = True
|
||||
yield self
|
||||
self._shielded = False
|
||||
|
||||
async def aclose(self):
|
||||
"""Cancel associated remote actor task and local memory channel
|
||||
on close.
|
||||
"""
|
||||
rx_chan = self._rx_chan
|
||||
|
||||
if rx_chan._closed:
|
||||
log.warning(f"{self} is already closed")
|
||||
return
|
||||
|
||||
# stats = rx_chan.statistics()
|
||||
# if stats.open_receive_channels > 1:
|
||||
# # if we've been cloned don't kill the stream
|
||||
# log.debug(
|
||||
# "there are still consumers running keeping stream alive")
|
||||
# return
|
||||
|
||||
if self._shielded:
|
||||
log.warning(f"{self} is shielded, portal channel being kept alive")
|
||||
return
|
||||
|
||||
# close the local mem chan
|
||||
rx_chan.close()
|
||||
|
||||
# cancel surrounding IPC context
|
||||
await self._ctx.cancel()
|
||||
|
||||
# TODO: but make it broadcasting to consumers
|
||||
# def clone(self):
|
||||
# """Clone this receive channel allowing for multi-task
|
||||
# consumption from the same channel.
|
||||
|
||||
# """
|
||||
# return ReceiveStream(
|
||||
# self._cid,
|
||||
# self._rx_chan.clone(),
|
||||
# self._portal,
|
||||
# )
|
||||
|
||||
|
||||
# class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
||||
# """
|
||||
# Bidirectional message stream for use within an inter-actor actor
|
||||
# ``Context```.
|
||||
|
||||
# """
|
||||
# async def send(
|
||||
# self,
|
||||
# data: Any
|
||||
# ) -> None:
|
||||
# await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
if 'ctx' not in params:
|
||||
raise TypeError(
|
||||
"The first argument to the context function "
|
||||
f"{func.__name__} must be `ctx: tractor.Context`"
|
||||
)
|
||||
return func
|
||||
|
|
Loading…
Reference in New Issue