forked from goodboy/tractor
Aggregate and organize streaming components
Move receive stream into streaming modules and rebrand as a "message stream". Factor out cancellation mechanics in `.aclose()` into the `Context` type which will soon provide the api for for cancelling portal invocations. Comment-stage a few methods on both types in anticipation of a new bi-directional streaming api. Add a `MsgStream` bidirectional channel type which will be the eventual type yielded from `Context.open_stream()`. Adjust the response/dialog types to be the set `{'asyncfun', 'asyncgen', 'context'}`. OH, and add async func checking in `Portal.run()` to catch and error on sync funcs early.stream_contexts
parent
a5a88e2f64
commit
7f38b7225d
|
@ -4,10 +4,9 @@ Portal api
|
|||
import importlib
|
||||
import inspect
|
||||
import typing
|
||||
from typing import Tuple, Any, Dict, Optional, Set, Iterator
|
||||
from typing import Tuple, Any, Dict, Optional, Set
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
from contextlib import contextmanager
|
||||
import warnings
|
||||
|
||||
import trio
|
||||
|
@ -17,9 +16,10 @@ from ._state import current_actor
|
|||
from ._ipc import Channel
|
||||
from .log import get_logger
|
||||
from ._exceptions import unpack_error, NoResult, RemoteActorError
|
||||
from ._streaming import Context, ReceiveMsgStream
|
||||
|
||||
|
||||
log = get_logger('tractor')
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
@ -39,113 +39,23 @@ async def maybe_open_nursery(
|
|||
yield nursery
|
||||
|
||||
|
||||
class ReceiveStream(trio.abc.ReceiveChannel):
|
||||
"""A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
|
||||
special behaviour for signalling stream termination across an
|
||||
inter-actor ``Channel``. This is the type returned to a local task
|
||||
which invoked a remote streaming function using `Portal.run()`.
|
||||
|
||||
Termination rules:
|
||||
- if the local task signals stop iteration a cancel signal is
|
||||
relayed to the remote task indicating to stop streaming
|
||||
- if the remote task signals the end of a stream, raise a
|
||||
``StopAsyncIteration`` to terminate the local ``async for``
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
cid: str,
|
||||
rx_chan: trio.abc.ReceiveChannel,
|
||||
portal: 'Portal',
|
||||
) -> None:
|
||||
self._cid = cid
|
||||
self._rx_chan = rx_chan
|
||||
self._portal = portal
|
||||
self._shielded = False
|
||||
|
||||
# delegate directly to underlying mem channel
|
||||
def receive_nowait(self):
|
||||
return self._rx_chan.receive_nowait()
|
||||
|
||||
async def receive(self):
|
||||
try:
|
||||
msg = await self._rx_chan.receive()
|
||||
return msg['yield']
|
||||
except trio.ClosedResourceError:
|
||||
# when the send is closed we assume the stream has
|
||||
# terminated and signal this local iterator to stop
|
||||
await self.aclose()
|
||||
raise StopAsyncIteration
|
||||
except trio.Cancelled:
|
||||
# relay cancels to the remote task
|
||||
await self.aclose()
|
||||
raise
|
||||
except KeyError:
|
||||
# internal error should never get here
|
||||
assert msg.get('cid'), (
|
||||
"Received internal error at portal?")
|
||||
raise unpack_error(msg, self._portal.channel)
|
||||
|
||||
@contextmanager
|
||||
def shield(
|
||||
self
|
||||
) -> Iterator['ReceiveStream']: # noqa
|
||||
"""Shield this stream's underlying channel such that a local consumer task
|
||||
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
||||
|
||||
"""
|
||||
self._shielded = True
|
||||
yield self
|
||||
self._shielded = False
|
||||
|
||||
async def aclose(self):
|
||||
"""Cancel associated remote actor task and local memory channel
|
||||
on close.
|
||||
"""
|
||||
if self._rx_chan._closed:
|
||||
log.warning(f"{self} is already closed")
|
||||
return
|
||||
|
||||
if self._shielded:
|
||||
log.warning(f"{self} is shielded, portal channel being kept alive")
|
||||
return
|
||||
|
||||
cid = self._cid
|
||||
with trio.move_on_after(0.5) as cs:
|
||||
cs.shield = True
|
||||
log.warning(
|
||||
f"Cancelling stream {cid} to "
|
||||
f"{self._portal.channel.uid}")
|
||||
|
||||
# NOTE: we're telling the far end actor to cancel a task
|
||||
# corresponding to *this actor*. The far end local channel
|
||||
# instance is passed to `Actor._cancel_task()` implicitly.
|
||||
await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
|
||||
|
||||
if cs.cancelled_caught:
|
||||
# XXX: there's no way to know if the remote task was indeed
|
||||
# cancelled in the case where the connection is broken or
|
||||
# some other network error occurred.
|
||||
if not self._portal.channel.connected():
|
||||
log.warning(
|
||||
"May have failed to cancel remote task "
|
||||
f"{cid} for {self._portal.channel.uid}")
|
||||
|
||||
with trio.CancelScope(shield=True):
|
||||
await self._rx_chan.aclose()
|
||||
|
||||
def clone(self):
|
||||
return self
|
||||
|
||||
|
||||
class Portal:
|
||||
"""A 'portal' to a(n) (remote) ``Actor``.
|
||||
|
||||
Allows for invoking remote routines and receiving results through an
|
||||
underlying ``tractor.Channel`` as though the remote (async)
|
||||
function / generator was invoked locally.
|
||||
A portal is "opened" (and eventually closed) by one side of an
|
||||
inter-actor communication context. The side which opens the portal
|
||||
is equivalent to a "caller" in function parlance and usually is
|
||||
either the called actor's parent (in process tree hierarchy terms)
|
||||
or a client interested in scheduling work to be done remotely in a
|
||||
far process.
|
||||
|
||||
The portal api allows the "caller" actor to invoke remote routines
|
||||
and receive results through an underlying ``tractor.Channel`` as
|
||||
though the remote (async) function / generator was called locally.
|
||||
It may be thought of loosely as an RPC api where native Python
|
||||
function calling semantics are supported transparently; hence it is
|
||||
like having a "portal" between the seperate actor memory spaces.
|
||||
|
||||
Think of this like a native async IPC API.
|
||||
"""
|
||||
def __init__(self, channel: Channel) -> None:
|
||||
self.channel = channel
|
||||
|
@ -157,7 +67,7 @@ class Portal:
|
|||
self._expect_result: Optional[
|
||||
Tuple[str, Any, str, Dict[str, Any]]
|
||||
] = None
|
||||
self._streams: Set[ReceiveStream] = set()
|
||||
self._streams: Set[ReceiveMsgStream] = set()
|
||||
self.actor = current_actor()
|
||||
|
||||
async def _submit(
|
||||
|
@ -182,55 +92,19 @@ class Portal:
|
|||
first_msg = await recv_chan.receive()
|
||||
functype = first_msg.get('functype')
|
||||
|
||||
if functype == 'asyncfunc':
|
||||
resp_type = 'return'
|
||||
elif functype == 'asyncgen':
|
||||
resp_type = 'yield'
|
||||
elif 'error' in first_msg:
|
||||
if 'error' in first_msg:
|
||||
raise unpack_error(first_msg, self.channel)
|
||||
else:
|
||||
|
||||
elif functype not in ('asyncfunc', 'asyncgen', 'context'):
|
||||
raise ValueError(f"{first_msg} is an invalid response packet?")
|
||||
|
||||
return cid, recv_chan, resp_type, first_msg
|
||||
return cid, recv_chan, functype, first_msg
|
||||
|
||||
async def _submit_for_result(self, ns: str, func: str, **kwargs) -> None:
|
||||
assert self._expect_result is None, \
|
||||
"A pending main result has already been submitted"
|
||||
self._expect_result = await self._submit(ns, func, kwargs)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
func_or_ns: str,
|
||||
fn_name: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""Submit a remote function to be scheduled and run by actor, in
|
||||
a new task, wrap and return its (stream of) result(s).
|
||||
|
||||
This is a blocking call and returns either a value from the
|
||||
remote rpc task or a local async generator instance.
|
||||
"""
|
||||
if isinstance(func_or_ns, str):
|
||||
warnings.warn(
|
||||
"`Portal.run(namespace: str, funcname: str)` is now"
|
||||
"deprecated, pass a function reference directly instead\n"
|
||||
"If you still want to run a remote function by name use"
|
||||
"`Portal.run_from_ns()`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
fn_mod_path = func_or_ns
|
||||
assert isinstance(fn_name, str)
|
||||
|
||||
else: # function reference was passed directly
|
||||
fn = func_or_ns
|
||||
fn_mod_path = fn.__module__
|
||||
fn_name = fn.__name__
|
||||
|
||||
return await self._return_from_resptype(
|
||||
*(await self._submit(fn_mod_path, fn_name, kwargs))
|
||||
)
|
||||
|
||||
async def run_from_ns(
|
||||
self,
|
||||
namespace_path: str,
|
||||
|
@ -260,15 +134,19 @@ class Portal:
|
|||
resptype: str,
|
||||
first_msg: dict
|
||||
) -> Any:
|
||||
# TODO: not this needs some serious work and thinking about how
|
||||
# to make async-generators the fundamental IPC API over channels!
|
||||
# (think `yield from`, `gen.send()`, and functional reactive stuff)
|
||||
if resptype == 'yield': # stream response
|
||||
rchan = ReceiveStream(cid, recv_chan, self)
|
||||
|
||||
# receive only stream
|
||||
if resptype == 'asyncgen':
|
||||
ctx = Context(self.channel, cid, _portal=self)
|
||||
rchan = ReceiveMsgStream(ctx, recv_chan, self)
|
||||
self._streams.add(rchan)
|
||||
return rchan
|
||||
|
||||
elif resptype == 'return': # single response
|
||||
elif resptype == 'context': # context manager style setup/teardown
|
||||
# TODO likely not here though
|
||||
raise NotImplementedError
|
||||
|
||||
elif resptype == 'asyncfunc': # single response
|
||||
msg = await recv_chan.receive()
|
||||
try:
|
||||
return msg['return']
|
||||
|
@ -369,6 +247,65 @@ class Portal:
|
|||
f"{self.channel} for {self.channel.uid} was already closed?")
|
||||
return False
|
||||
|
||||
async def run(
|
||||
self,
|
||||
func: str,
|
||||
fn_name: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""Submit a remote function to be scheduled and run by actor, in
|
||||
a new task, wrap and return its (stream of) result(s).
|
||||
|
||||
This is a blocking call and returns either a value from the
|
||||
remote rpc task or a local async generator instance.
|
||||
"""
|
||||
if isinstance(func, str):
|
||||
warnings.warn(
|
||||
"`Portal.run(namespace: str, funcname: str)` is now"
|
||||
"deprecated, pass a function reference directly instead\n"
|
||||
"If you still want to run a remote function by name use"
|
||||
"`Portal.run_from_ns()`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
fn_mod_path = func
|
||||
assert isinstance(fn_name, str)
|
||||
|
||||
else: # function reference was passed directly
|
||||
|
||||
# TODO: ensure async
|
||||
if not (
|
||||
inspect.isasyncgenfunction(func) or
|
||||
inspect.iscoroutinefunction(func)
|
||||
):
|
||||
raise TypeError(f'{func} must be an async function!')
|
||||
|
||||
fn = func
|
||||
fn_mod_path = fn.__module__
|
||||
fn_name = fn.__name__
|
||||
|
||||
return await self._return_from_resptype(
|
||||
*(await self._submit(fn_mod_path, fn_name, kwargs))
|
||||
)
|
||||
|
||||
# @asynccontextmanager
|
||||
# async def open_stream_from(
|
||||
# self,
|
||||
# async_gen: 'AsyncGeneratorFunction',
|
||||
# **kwargs,
|
||||
# ) -> ReceiveMsgStream:
|
||||
# # TODO
|
||||
# pass
|
||||
|
||||
# @asynccontextmanager
|
||||
# async def open_context(
|
||||
# self,
|
||||
# func: Callable,
|
||||
# **kwargs,
|
||||
# ) -> Context:
|
||||
# # TODO
|
||||
# pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalPortal:
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import inspect
|
||||
from contextvars import ContextVar
|
||||
from contextlib import contextmanager # , asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Iterator, Optional
|
||||
import warnings
|
||||
|
||||
import trio
|
||||
|
||||
from ._ipc import Channel
|
||||
from ._exceptions import unpack_error
|
||||
from .log import get_logger
|
||||
|
||||
|
||||
_context: ContextVar['Context'] = ContextVar('context')
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -18,22 +21,73 @@ class Context:
|
|||
Allows maintaining task or protocol specific state between communicating
|
||||
actors. A unique context is created on the receiving end for every request
|
||||
to a remote actor.
|
||||
|
||||
A context can be cancelled and (eventually) restarted from
|
||||
either side of the underlying IPC channel.
|
||||
|
||||
A context can be used to open task oriented message streams.
|
||||
|
||||
"""
|
||||
chan: Channel
|
||||
cid: str
|
||||
cancel_scope: trio.CancelScope
|
||||
|
||||
# only set on the caller side
|
||||
_portal: Optional['Portal'] = None # type: ignore # noqa
|
||||
|
||||
# only set on the callee side
|
||||
_cancel_scope: Optional[trio.CancelScope] = None
|
||||
|
||||
async def send_yield(self, data: Any) -> None:
|
||||
|
||||
warnings.warn(
|
||||
"`Context.send_yield()` is now deprecated. "
|
||||
"Use ``MessageStream.send()``. ",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
await self.chan.send({'yield': data, 'cid': self.cid})
|
||||
|
||||
async def send_stop(self) -> None:
|
||||
await self.chan.send({'stop': True, 'cid': self.cid})
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel this inter-actor-task context.
|
||||
|
||||
def current_context():
|
||||
"""Get the current task's context instance.
|
||||
"""
|
||||
return _context.get()
|
||||
Request that the far side cancel it's current linked context,
|
||||
timeout quickly to sidestep 2-generals...
|
||||
|
||||
"""
|
||||
cid = self.cid
|
||||
with trio.move_on_after(0.5) as cs:
|
||||
cs.shield = True
|
||||
log.warning(
|
||||
f"Cancelling stream {cid} to "
|
||||
f"{self._portal.channel.uid}")
|
||||
|
||||
# NOTE: we're telling the far end actor to cancel a task
|
||||
# corresponding to *this actor*. The far end local channel
|
||||
# instance is passed to `Actor._cancel_task()` implicitly.
|
||||
await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
|
||||
|
||||
if cs.cancelled_caught:
|
||||
# XXX: there's no way to know if the remote task was indeed
|
||||
# cancelled in the case where the connection is broken or
|
||||
# some other network error occurred.
|
||||
if not self._portal.channel.connected():
|
||||
log.warning(
|
||||
"May have failed to cancel remote task "
|
||||
f"{cid} for {self._portal.channel.uid}")
|
||||
|
||||
# async def restart(self) -> None:
|
||||
# # TODO
|
||||
# pass
|
||||
|
||||
# @asynccontextmanager
|
||||
# async def open_stream(
|
||||
# self,
|
||||
# ) -> AsyncContextManager:
|
||||
# # TODO
|
||||
# pass
|
||||
|
||||
|
||||
def stream(func):
|
||||
|
@ -47,3 +101,146 @@ def stream(func):
|
|||
f"{func.__name__} must be `ctx: tractor.Context`"
|
||||
)
|
||||
return func
|
||||
|
||||
|
||||
class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||
"""A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
|
||||
special behaviour for signalling stream termination across an
|
||||
inter-actor ``Channel``. This is the type returned to a local task
|
||||
which invoked a remote streaming function using `Portal.run()`.
|
||||
|
||||
Termination rules:
|
||||
- if the local task signals stop iteration a cancel signal is
|
||||
relayed to the remote task indicating to stop streaming
|
||||
- if the remote task signals the end of a stream, raise a
|
||||
``StopAsyncIteration`` to terminate the local ``async for``
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
ctx: Context,
|
||||
rx_chan: trio.abc.ReceiveChannel,
|
||||
portal: 'Portal', # noqa
|
||||
) -> None:
|
||||
self._ctx = ctx
|
||||
self._rx_chan = rx_chan
|
||||
self._portal = portal
|
||||
# self._chan = portal.channel
|
||||
self._shielded = False
|
||||
|
||||
# delegate directly to underlying mem channel
|
||||
def receive_nowait(self):
|
||||
return self._rx_chan.receive_nowait()
|
||||
|
||||
async def receive(self):
|
||||
try:
|
||||
msg = await self._rx_chan.receive()
|
||||
return msg['yield']
|
||||
# return msg['yield']
|
||||
|
||||
except KeyError:
|
||||
# internal error should never get here
|
||||
assert msg.get('cid'), ("Received internal error at portal?")
|
||||
|
||||
# TODO: handle 2 cases with 3.10 match syntax
|
||||
# - 'stop'
|
||||
# - 'error'
|
||||
# possibly just handle msg['stop'] here!
|
||||
|
||||
# TODO: test that shows stream raising an expected error!!!
|
||||
if msg.get('error'):
|
||||
# raise the error message
|
||||
raise unpack_error(msg, self._portal.channel)
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
# XXX: this indicates that a `stop` message was
|
||||
# sent by the far side of the underlying channel.
|
||||
# Currently this is triggered by calling ``.aclose()`` on
|
||||
# the send side of the channel inside
|
||||
# ``Actor._push_result()``, but maybe it should be put here?
|
||||
# to avoid exposing the internal mem chan closing mechanism?
|
||||
# in theory we could instead do some flushing of the channel
|
||||
# if needed to ensure all consumers are complete before
|
||||
# triggering closure too early?
|
||||
|
||||
# Locally, we want to close this stream gracefully, by
|
||||
# terminating any local consumers tasks deterministically.
|
||||
# We **don't** want to be closing this send channel and not
|
||||
# relaying a final value to remaining consumers who may not
|
||||
# have been scheduled to receive it yet?
|
||||
|
||||
# lots of testing to do here
|
||||
|
||||
# when the send is closed we assume the stream has
|
||||
# terminated and signal this local iterator to stop
|
||||
await self.aclose()
|
||||
raise StopAsyncIteration
|
||||
|
||||
except trio.Cancelled:
|
||||
# relay cancels to the remote task
|
||||
await self.aclose()
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def shield(
|
||||
self
|
||||
) -> Iterator['ReceiveStream']: # noqa
|
||||
"""Shield this stream's underlying channel such that a local consumer task
|
||||
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
||||
|
||||
"""
|
||||
self._shielded = True
|
||||
yield self
|
||||
self._shielded = False
|
||||
|
||||
async def aclose(self):
|
||||
"""Cancel associated remote actor task and local memory channel
|
||||
on close.
|
||||
"""
|
||||
rx_chan = self._rx_chan
|
||||
|
||||
if rx_chan._closed:
|
||||
log.warning(f"{self} is already closed")
|
||||
return
|
||||
|
||||
# stats = rx_chan.statistics()
|
||||
# if stats.open_receive_channels > 1:
|
||||
# # if we've been cloned don't kill the stream
|
||||
# log.debug(
|
||||
# "there are still consumers running keeping stream alive")
|
||||
# return
|
||||
|
||||
if self._shielded:
|
||||
log.warning(f"{self} is shielded, portal channel being kept alive")
|
||||
return
|
||||
|
||||
# close the local mem chan
|
||||
rx_chan.close()
|
||||
|
||||
# cancel surrounding IPC context
|
||||
await self._ctx.cancel()
|
||||
|
||||
# TODO: but make it broadcasting to consumers
|
||||
# def clone(self):
|
||||
# """Clone this receive channel allowing for multi-task
|
||||
# consumption from the same channel.
|
||||
|
||||
# """
|
||||
# return ReceiveStream(
|
||||
# self._cid,
|
||||
# self._rx_chan.clone(),
|
||||
# self._portal,
|
||||
# )
|
||||
|
||||
|
||||
class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
||||
"""
|
||||
Bidirectional message stream for use within an inter-actor actor
|
||||
``Context```.
|
||||
|
||||
"""
|
||||
async def send(
|
||||
self,
|
||||
data: Any
|
||||
) -> None:
|
||||
await self._chan.send({'yield': data, 'cid': self._cid})
|
||||
|
|
Loading…
Reference in New Issue