diff --git a/tractor/__init__.py b/tractor/__init__.py index d14717b..fb9b095 100644 --- a/tractor/__init__.py +++ b/tractor/__init__.py @@ -11,7 +11,8 @@ import trio # type: ignore from trio import MultiError from . import log -from ._ipc import _connect_chan, Channel, Context +from ._ipc import _connect_chan, Channel +from ._streaming import Context, stream from ._discovery import get_arbiter, find_actor, wait_for_actor from ._actor import Actor, _start_actor, Arbiter from ._trionics import open_nursery @@ -29,6 +30,7 @@ __all__ = [ 'wait_for_actor', 'Channel', 'Context', + 'stream', 'MultiError', 'RemoteActorError', 'ModuleNotExposed', diff --git a/tractor/_actor.py b/tractor/_actor.py index 496c68a..52b301f 100644 --- a/tractor/_actor.py +++ b/tractor/_actor.py @@ -13,7 +13,8 @@ from typing import Dict, List, Tuple, Any, Optional import trio # type: ignore from async_generator import aclosing -from ._ipc import Channel, Context +from ._ipc import Channel +from ._streaming import Context, _context from .log import get_console_log, get_logger from ._exceptions import ( pack_error, @@ -47,7 +48,13 @@ async def _invoke( cs = None cancel_scope = trio.CancelScope() ctx = Context(chan, cid, cancel_scope) - if 'ctx' in sig.parameters: + _context.set(ctx) + if getattr(func, '_tractor_stream_function', False): + if 'ctx' not in sig.parameters: + raise TypeError( + "The first argument to the stream function " + f"{func.__name__} must be `ctx: tractor.Context`" + ) kwargs['ctx'] = ctx # TODO: eventually we want to be more stringent # about what is considered a far-end async-generator. diff --git a/tractor/_ipc.py b/tractor/_ipc.py index 5acb79a..94f978f 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -1,7 +1,6 @@ """ Inter-process comms abstractions """ -from dataclasses import dataclass import typing from typing import Any, Tuple, Optional @@ -205,25 +204,6 @@ class Channel: return self.msgstream.connected() if self.msgstream else False -@dataclass(frozen=True) -class Context: - """An IAC (inter-actor communication) context. - - Allows maintaining task or protocol specific state between communicating - actors. A unique context is created on the receiving end for every request - to a remote actor. - """ - chan: Channel - cid: str - cancel_scope: trio.CancelScope - - async def send_yield(self, data: Any) -> None: - await self.chan.send({'yield': data, 'cid': self.cid}) - - async def send_stop(self) -> None: - await self.chan.send({'stop': True, 'cid': self.cid}) - - @asynccontextmanager async def _connect_chan( host: str, port: int diff --git a/tractor/_streaming.py b/tractor/_streaming.py new file mode 100644 index 0000000..0ccf5fc --- /dev/null +++ b/tractor/_streaming.py @@ -0,0 +1,43 @@ +import contextvars +from dataclasses import dataclass +from typing import Any + +import trio + +from ._ipc import Channel + + +_context = contextvars.ContextVar('context') + + +@dataclass(frozen=True) +class Context: + """An IAC (inter-actor communication) context. + + Allows maintaining task or protocol specific state between communicating + actors. A unique context is created on the receiving end for every request + to a remote actor. + """ + chan: Channel + cid: str + cancel_scope: trio.CancelScope + + async def send_yield(self, data: Any) -> None: + await self.chan.send({'yield': data, 'cid': self.cid}) + + async def send_stop(self) -> None: + await self.chan.send({'stop': True, 'cid': self.cid}) + + +def current_context(): + """Get the current streaming task's context instance. + + """ + return _context.get() + + +def stream(func): + """Mark an async function as a streaming routine. + """ + func._tractor_stream_function = True + return func diff --git a/tractor/msg.py b/tractor/msg.py index 73e39d5..59842ef 100644 --- a/tractor/msg.py +++ b/tractor/msg.py @@ -12,7 +12,7 @@ import wrapt from .log import get_logger from . import current_actor -from ._ipc import Context +from ._streaming import Context, stream __all__ = ['pub'] @@ -261,4 +261,4 @@ def pub( "`get_topics` argument" ) - return wrapper(wrapped) + return wrapper(stream(wrapped))