Add initial bi-directional streaming
This mostly adds the api described in https://github.com/goodboy/tractor/issues/53#issuecomment-806258798 The first draft summary: - formalize bidir steaming using the `trio.Channel` style interface which we derive as a `MsgStream` type. - add `Portal.open_context()` which provides a `trio.Nursery.start()` remote task invocation style for setting up and tearing down tasks contexts in remote actors. - add a distinct `'started'` message to the ipc protocol to facilitate `Context.start()` with a first return value. - for our `ReceiveMsgStream` type, don't cancel the remote task in `.aclose()`; this is now done explicitly by the surrounding `Context` usage: `Context.cancel()`. - streams in either direction still use a `'yield'` message keeping the proto mostly symmetric without having to worry about which side is the caller / portal opener. - subtlety: only allow sending a `'stop'` message during a 2-way streaming context from `ReceiveStream.aclose()`, detailed comment with explanation is included. Relates to #53wip_fix_asyncio_gen_streaming
							parent
							
								
									f48548ab94
								
							
						
					
					
						commit
						15fa777ddf
					
				|  | @ -14,6 +14,7 @@ from types import ModuleType | |||
| import sys | ||||
| import os | ||||
| from contextlib import ExitStack | ||||
| import warnings | ||||
| 
 | ||||
| import trio  # type: ignore | ||||
| from trio_typing import TaskStatus | ||||
|  | @ -57,13 +58,37 @@ async def _invoke( | |||
|     treat_as_gen = False | ||||
|     cs = None | ||||
|     cancel_scope = trio.CancelScope() | ||||
|     ctx = Context(chan, cid, cancel_scope) | ||||
|     ctx = Context(chan, cid, _cancel_scope=cancel_scope) | ||||
|     context = False | ||||
| 
 | ||||
|     if getattr(func, '_tractor_stream_function', False): | ||||
|         # handle decorated ``@tractor.stream`` async functions | ||||
|         sig = inspect.signature(func) | ||||
|         params = sig.parameters | ||||
| 
 | ||||
|         # compat with old api | ||||
|         kwargs['ctx'] = ctx | ||||
| 
 | ||||
|         if 'ctx' in params: | ||||
|             warnings.warn( | ||||
|                 "`@tractor.stream decorated funcs should now declare " | ||||
|                 "a `stream`  arg, `ctx` is now designated for use with " | ||||
|                 "@tractor.context", | ||||
|                 DeprecationWarning, | ||||
|                 stacklevel=2, | ||||
|             ) | ||||
| 
 | ||||
|         elif 'stream' in params: | ||||
|             assert 'stream' in params | ||||
|             kwargs['stream'] = ctx | ||||
| 
 | ||||
|         treat_as_gen = True | ||||
| 
 | ||||
|     elif getattr(func, '_tractor_context_function', False): | ||||
|         # handle decorated ``@tractor.context`` async function | ||||
|         kwargs['ctx'] = ctx | ||||
|         context = True | ||||
| 
 | ||||
|     # errors raised inside this block are propgated back to caller | ||||
|     try: | ||||
|         if not ( | ||||
|  | @ -101,26 +126,41 @@ async def _invoke( | |||
|             # `StopAsyncIteration` system here for returning a final | ||||
|             # value if desired | ||||
|             await chan.send({'stop': True, 'cid': cid}) | ||||
| 
 | ||||
|         # one way @stream func that gets treated like an async gen | ||||
|         elif treat_as_gen: | ||||
|             await chan.send({'functype': 'asyncgen', 'cid': cid}) | ||||
|             # XXX: the async-func may spawn further tasks which push | ||||
|             # back values like an async-generator would but must | ||||
|             # manualy construct the response dict-packet-responses as | ||||
|             # above | ||||
|             with cancel_scope as cs: | ||||
|                 task_status.started(cs) | ||||
|                 await coro | ||||
| 
 | ||||
|             if not cs.cancelled_caught: | ||||
|                 # task was not cancelled so we can instruct the | ||||
|                 # far end async gen to tear down | ||||
|                 await chan.send({'stop': True, 'cid': cid}) | ||||
| 
 | ||||
|         elif context: | ||||
|             # context func with support for bi-dir streaming | ||||
|             await chan.send({'functype': 'context', 'cid': cid}) | ||||
| 
 | ||||
|             with cancel_scope as cs: | ||||
|                 task_status.started(cs) | ||||
|                 await chan.send({'return': await coro, 'cid': cid}) | ||||
| 
 | ||||
|             # if cs.cancelled_caught: | ||||
|             #     # task was cancelled so relay to the cancel to caller | ||||
|             #     await chan.send({'return': await coro, 'cid': cid}) | ||||
| 
 | ||||
|         else: | ||||
|             if treat_as_gen: | ||||
|                 await chan.send({'functype': 'asyncgen', 'cid': cid}) | ||||
|                 # XXX: the async-func may spawn further tasks which push | ||||
|                 # back values like an async-generator would but must | ||||
|                 # manualy construct the response dict-packet-responses as | ||||
|                 # above | ||||
|                 with cancel_scope as cs: | ||||
|                     task_status.started(cs) | ||||
|                     await coro | ||||
|                 if not cs.cancelled_caught: | ||||
|                     # task was not cancelled so we can instruct the | ||||
|                     # far end async gen to tear down | ||||
|                     await chan.send({'stop': True, 'cid': cid}) | ||||
|             else: | ||||
|                 # regular async function | ||||
|                 await chan.send({'functype': 'asyncfunc', 'cid': cid}) | ||||
|                 with cancel_scope as cs: | ||||
|                     task_status.started(cs) | ||||
|                     await chan.send({'return': await coro, 'cid': cid}) | ||||
|             # regular async function | ||||
|             await chan.send({'functype': 'asyncfunc', 'cid': cid}) | ||||
|             with cancel_scope as cs: | ||||
|                 task_status.started(cs) | ||||
|                 await chan.send({'return': await coro, 'cid': cid}) | ||||
| 
 | ||||
|     except (Exception, trio.MultiError) as err: | ||||
| 
 | ||||
|  | @ -404,10 +444,10 @@ class Actor: | |||
|         send_chan, recv_chan = self._cids2qs[(actorid, cid)] | ||||
|         assert send_chan.cid == cid  # type: ignore | ||||
| 
 | ||||
|         if 'stop' in msg: | ||||
|             log.debug(f"{send_chan} was terminated at remote end") | ||||
|             # indicate to consumer that far end has stopped | ||||
|             return await send_chan.aclose() | ||||
|         # if 'stop' in msg: | ||||
|         #     log.debug(f"{send_chan} was terminated at remote end") | ||||
|         #     # indicate to consumer that far end has stopped | ||||
|         #     return await send_chan.aclose() | ||||
| 
 | ||||
|         try: | ||||
|             log.debug(f"Delivering {msg} from {actorid} to caller {cid}") | ||||
|  | @ -415,6 +455,12 @@ class Actor: | |||
|             await send_chan.send(msg) | ||||
| 
 | ||||
|         except trio.BrokenResourceError: | ||||
|             # TODO: what is the right way to handle the case where the | ||||
|             # local task has already sent a 'stop' / StopAsyncInteration | ||||
|             # to the other side but and possibly has closed the local | ||||
|             # feeder mem chan? Do we wait for some kind of ack or just | ||||
|             # let this fail silently and bubble up (currently)? | ||||
| 
 | ||||
|             # XXX: local consumer has closed their side | ||||
|             # so cancel the far end streaming task | ||||
|             log.warning(f"{send_chan} consumer is already closed") | ||||
|  | @ -494,6 +540,7 @@ class Actor: | |||
|                     if cid: | ||||
|                         # deliver response to local caller/waiter | ||||
|                         await self._push_result(chan, cid, msg) | ||||
| 
 | ||||
|                         log.debug( | ||||
|                             f"Waiting on next msg for {chan} from {chan.uid}") | ||||
|                         continue | ||||
|  |  | |||
|  | @ -312,11 +312,20 @@ class Portal: | |||
| 
 | ||||
|         ctx = Context(self.channel, cid, _portal=self) | ||||
|         try: | ||||
|             async with ReceiveMsgStream(ctx, recv_chan, self) as rchan: | ||||
|             # deliver receive only stream | ||||
|             async with ReceiveMsgStream(ctx, recv_chan) as rchan: | ||||
|                 self._streams.add(rchan) | ||||
|                 yield rchan | ||||
| 
 | ||||
|         finally: | ||||
| 
 | ||||
|             # cancel the far end task on consumer close | ||||
|             # NOTE: this is a special case since we assume that if using | ||||
|             # this ``.open_fream_from()`` api, the stream is one a one | ||||
|             # time use and we couple the far end tasks's lifetime to | ||||
|             # the consumer's scope; we don't ever send a `'stop'` | ||||
|             # message right now since there shouldn't be a reason to | ||||
|             # stop and restart the stream, right? | ||||
|             try: | ||||
|                 await ctx.cancel() | ||||
|             except trio.ClosedResourceError: | ||||
|  | @ -326,16 +335,55 @@ class Portal: | |||
| 
 | ||||
|             self._streams.remove(rchan) | ||||
| 
 | ||||
|     # @asynccontextmanager | ||||
|     # async def open_context( | ||||
|     #     self, | ||||
|     #     func: Callable, | ||||
|     #     **kwargs, | ||||
|     # ) -> Context: | ||||
|     #     # TODO | ||||
|     #     elif resptype == 'context':  # context manager style setup/teardown | ||||
|     #         # TODO likely not here though | ||||
|     #         raise NotImplementedError | ||||
|     @asynccontextmanager | ||||
|     async def open_context( | ||||
|         self, | ||||
|         func: Callable, | ||||
|         **kwargs, | ||||
|     ) -> Context: | ||||
|         """Open an inter-actor task context. | ||||
| 
 | ||||
|         This is a synchronous API which allows for deterministic | ||||
|         setup/teardown of a remote task. The yielded ``Context`` further | ||||
|         allows for opening bidirectional streams - see | ||||
|         ``Context.open_stream()``. | ||||
| 
 | ||||
|         """ | ||||
|         # conduct target func method structural checks | ||||
|         if not inspect.iscoroutinefunction(func) and ( | ||||
|             getattr(func, '_tractor_contex_function', False) | ||||
|         ): | ||||
|             raise TypeError( | ||||
|                 f'{func} must be an async generator function!') | ||||
| 
 | ||||
|         fn_mod_path, fn_name = func_deats(func) | ||||
| 
 | ||||
|         cid, recv_chan, functype, first_msg = await self._submit( | ||||
|             fn_mod_path, fn_name, kwargs) | ||||
| 
 | ||||
|         assert functype == 'context' | ||||
| 
 | ||||
|         msg = await recv_chan.receive() | ||||
|         try: | ||||
|             # the "first" value here is delivered by the callee's | ||||
|             # ``Context.started()`` call. | ||||
|             first = msg['started'] | ||||
| 
 | ||||
|         except KeyError: | ||||
|             assert msg.get('cid'), ("Received internal error at context?") | ||||
| 
 | ||||
|             if msg.get('error'): | ||||
|                 # raise the error message | ||||
|                 raise unpack_error(msg, self.channel) | ||||
|             else: | ||||
|                 raise | ||||
|         try: | ||||
|             ctx = Context(self.channel, cid, _portal=self) | ||||
|             yield ctx, first | ||||
| 
 | ||||
|         finally: | ||||
|             await recv_chan.aclose() | ||||
|             await ctx.cancel() | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
|  |  | |||
|  | @ -1,19 +1,195 @@ | |||
| import inspect | ||||
| from contextlib import contextmanager  # , asynccontextmanager | ||||
| from contextlib import contextmanager, asynccontextmanager | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Iterator, Optional | ||||
| from typing import Any, Iterator, Optional, Callable | ||||
| import warnings | ||||
| 
 | ||||
| import trio | ||||
| 
 | ||||
| from ._ipc import Channel | ||||
| from ._exceptions import unpack_error | ||||
| from ._state import current_actor | ||||
| from .log import get_logger | ||||
| 
 | ||||
| 
 | ||||
| log = get_logger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| # TODO: generic typing like trio's receive channel | ||||
| # but with msgspec messages? | ||||
| # class ReceiveChannel(AsyncResource, Generic[ReceiveType]): | ||||
| 
 | ||||
| 
 | ||||
| class ReceiveMsgStream(trio.abc.ReceiveChannel): | ||||
|     """A wrapper around a ``trio._channel.MemoryReceiveChannel`` with | ||||
|     special behaviour for signalling stream termination across an | ||||
|     inter-actor ``Channel``. This is the type returned to a local task | ||||
|     which invoked a remote streaming function using `Portal.run()`. | ||||
| 
 | ||||
|     Termination rules: | ||||
|     - if the local task signals stop iteration a cancel signal is | ||||
|       relayed to the remote task indicating to stop streaming | ||||
|     - if the remote task signals the end of a stream, raise a | ||||
|       ``StopAsyncIteration`` to terminate the local ``async for`` | ||||
| 
 | ||||
|     """ | ||||
|     def __init__( | ||||
|         self, | ||||
|         ctx: 'Context',  # typing: ignore # noqa | ||||
|         rx_chan: trio.abc.ReceiveChannel, | ||||
|     ) -> None: | ||||
|         self._ctx = ctx | ||||
|         self._rx_chan = rx_chan | ||||
|         self._shielded = False | ||||
| 
 | ||||
|     # delegate directly to underlying mem channel | ||||
|     def receive_nowait(self): | ||||
|         return self._rx_chan.receive_nowait() | ||||
| 
 | ||||
|     async def receive(self): | ||||
|         try: | ||||
|             msg = await self._rx_chan.receive() | ||||
|             return msg['yield'] | ||||
| 
 | ||||
|         except KeyError: | ||||
|             # internal error should never get here | ||||
|             assert msg.get('cid'), ("Received internal error at portal?") | ||||
| 
 | ||||
|             # TODO: handle 2 cases with 3.10 match syntax | ||||
|             # - 'stop' | ||||
|             # - 'error' | ||||
|             # possibly just handle msg['stop'] here! | ||||
| 
 | ||||
|             if msg.get('stop'): | ||||
|                 log.debug(f"{self} was stopped at remote end") | ||||
|                 # when the send is closed we assume the stream has | ||||
|                 # terminated and signal this local iterator to stop | ||||
|                 await self.aclose() | ||||
|                 raise trio.EndOfChannel | ||||
| 
 | ||||
|             # TODO: test that shows stream raising an expected error!!! | ||||
|             elif msg.get('error'): | ||||
|                 # raise the error message | ||||
|                 raise unpack_error(msg, self._ctx.chan) | ||||
| 
 | ||||
|             else: | ||||
|                 raise | ||||
| 
 | ||||
|         except (trio.ClosedResourceError, StopAsyncIteration): | ||||
|             # XXX: this indicates that a `stop` message was | ||||
|             # sent by the far side of the underlying channel. | ||||
|             # Currently this is triggered by calling ``.aclose()`` on | ||||
|             # the send side of the channel inside | ||||
|             # ``Actor._push_result()``, but maybe it should be put here? | ||||
|             # to avoid exposing the internal mem chan closing mechanism? | ||||
|             # in theory we could instead do some flushing of the channel | ||||
|             # if needed to ensure all consumers are complete before | ||||
|             # triggering closure too early? | ||||
| 
 | ||||
|             # Locally, we want to close this stream gracefully, by | ||||
|             # terminating any local consumers tasks deterministically. | ||||
|             # We **don't** want to be closing this send channel and not | ||||
|             # relaying a final value to remaining consumers who may not | ||||
|             # have been scheduled to receive it yet? | ||||
| 
 | ||||
|             # lots of testing to do here | ||||
| 
 | ||||
|             # when the send is closed we assume the stream has | ||||
|             # terminated and signal this local iterator to stop | ||||
|             await self.aclose() | ||||
|             # await self._ctx.send_stop() | ||||
|             raise StopAsyncIteration | ||||
| 
 | ||||
|         except trio.Cancelled: | ||||
|             # relay cancels to the remote task | ||||
|             await self.aclose() | ||||
|             raise | ||||
| 
 | ||||
|     @contextmanager | ||||
|     def shield( | ||||
|         self | ||||
|     ) -> Iterator['ReceiveMsgStream']:  # noqa | ||||
|         """Shield this stream's underlying channel such that a local consumer task | ||||
|         can be cancelled (and possibly restarted) using ``trio.Cancelled``. | ||||
| 
 | ||||
|         """ | ||||
|         self._shielded = True | ||||
|         yield self | ||||
|         self._shielded = False | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         """Cancel associated remote actor task and local memory channel | ||||
|         on close. | ||||
| 
 | ||||
|         """ | ||||
|         # TODO: proper adherance to trio's `.aclose()` semantics: | ||||
|         # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose | ||||
|         rx_chan = self._rx_chan | ||||
| 
 | ||||
|         if rx_chan._closed: | ||||
|             log.warning(f"{self} is already closed") | ||||
|             return | ||||
| 
 | ||||
|         # TODO: broadcasting to multiple consumers | ||||
|         # stats = rx_chan.statistics() | ||||
|         # if stats.open_receive_channels > 1: | ||||
|         #     # if we've been cloned don't kill the stream | ||||
|         #     log.debug( | ||||
|         #       "there are still consumers running keeping stream alive") | ||||
|         #     return | ||||
| 
 | ||||
|         if self._shielded: | ||||
|             log.warning(f"{self} is shielded, portal channel being kept alive") | ||||
|             return | ||||
| 
 | ||||
|         # NOTE: this is super subtle IPC messaging stuff: | ||||
|         # Relay stop iteration to far end **iff** we're | ||||
|         # in bidirectional mode. If we're only streaming | ||||
|         # *from* one side then that side **won't** have an | ||||
|         # entry in `Actor._cids2qs` (maybe it should though?). | ||||
|         # So any `yield` or `stop` msgs sent from the caller side | ||||
|         # will cause key errors on the callee side since there is | ||||
|         # no entry for a local feeder mem chan since the callee task | ||||
|         # isn't expecting messages to be sent by the caller. | ||||
|         # Thus, we must check that this context DOES NOT | ||||
|         # have a portal reference to ensure this is indeed the callee | ||||
|         # side and can relay a 'stop'. In the bidirectional case, | ||||
|         # `Context.open_stream()` will create the `Actor._cids2qs` | ||||
|         # entry from a call to `Actor.get_memchans()`. | ||||
|         if not self._ctx._portal: | ||||
|             # only for 2 way streams can we can send | ||||
|             # stop from the caller side | ||||
|             await self._ctx.send_stop() | ||||
| 
 | ||||
|         # close the local mem chan | ||||
|         rx_chan.close() | ||||
| 
 | ||||
|     # TODO: but make it broadcasting to consumers | ||||
|     # def clone(self): | ||||
|     #     """Clone this receive channel allowing for multi-task | ||||
|     #     consumption from the same channel. | ||||
| 
 | ||||
|     #     """ | ||||
|     #     return ReceiveStream( | ||||
|     #         self._cid, | ||||
|     #         self._rx_chan.clone(), | ||||
|     #         self._portal, | ||||
|     #     ) | ||||
| 
 | ||||
| 
 | ||||
| class MsgStream(ReceiveMsgStream, trio.abc.Channel): | ||||
|     """ | ||||
|     Bidirectional message stream for use within an inter-actor actor | ||||
|     ``Context```. | ||||
| 
 | ||||
|     """ | ||||
|     async def send( | ||||
|         self, | ||||
|         data: Any | ||||
|     ) -> None: | ||||
|         await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) | ||||
| 
 | ||||
| 
 | ||||
| @dataclass(frozen=True) | ||||
| class Context: | ||||
|     """An IAC (inter-actor communication) context. | ||||
|  | @ -31,6 +207,10 @@ class Context: | |||
|     chan: Channel | ||||
|     cid: str | ||||
| 
 | ||||
|     # TODO: should we have seperate types for caller vs. callee | ||||
|     # side contexts? The caller always opens a portal whereas the callee | ||||
|     # is always responding back through a context-stream | ||||
| 
 | ||||
|     # only set on the caller side | ||||
|     _portal: Optional['Portal'] = None    # type: ignore # noqa | ||||
| 
 | ||||
|  | @ -57,46 +237,97 @@ class Context: | |||
|         timeout quickly to sidestep 2-generals... | ||||
| 
 | ||||
|         """ | ||||
|         assert self._portal, ( | ||||
|             "No portal found, this is likely a callee side context") | ||||
|         if self._portal:  # caller side: | ||||
|             if not self._portal: | ||||
|                 raise RuntimeError( | ||||
|                     "No portal found, this is likely a callee side context" | ||||
|                 ) | ||||
| 
 | ||||
|         cid = self.cid | ||||
|         with trio.move_on_after(0.5) as cs: | ||||
|             cs.shield = True | ||||
|             log.warning( | ||||
|                 f"Cancelling stream {cid} to " | ||||
|                 f"{self._portal.channel.uid}") | ||||
| 
 | ||||
|             # NOTE: we're telling the far end actor to cancel a task | ||||
|             # corresponding to *this actor*. The far end local channel | ||||
|             # instance is passed to `Actor._cancel_task()` implicitly. | ||||
|             await self._portal.run_from_ns('self', '_cancel_task', cid=cid) | ||||
| 
 | ||||
|         if cs.cancelled_caught: | ||||
|             # XXX: there's no way to know if the remote task was indeed | ||||
|             # cancelled in the case where the connection is broken or | ||||
|             # some other network error occurred. | ||||
|             if not self._portal.channel.connected(): | ||||
|             cid = self.cid | ||||
|             with trio.move_on_after(0.5) as cs: | ||||
|                 cs.shield = True | ||||
|                 log.warning( | ||||
|                     "May have failed to cancel remote task " | ||||
|                     f"{cid} for {self._portal.channel.uid}") | ||||
|                     f"Cancelling stream {cid} to " | ||||
|                     f"{self._portal.channel.uid}") | ||||
| 
 | ||||
|                 # NOTE: we're telling the far end actor to cancel a task | ||||
|                 # corresponding to *this actor*. The far end local channel | ||||
|                 # instance is passed to `Actor._cancel_task()` implicitly. | ||||
|                 await self._portal.run_from_ns('self', '_cancel_task', cid=cid) | ||||
| 
 | ||||
|             if cs.cancelled_caught: | ||||
|                 # XXX: there's no way to know if the remote task was indeed | ||||
|                 # cancelled in the case where the connection is broken or | ||||
|                 # some other network error occurred. | ||||
|                 # if not self._portal.channel.connected(): | ||||
|                 if not self.chan.connected(): | ||||
|                     log.warning( | ||||
|                         "May have failed to cancel remote task " | ||||
|                         f"{cid} for {self._portal.channel.uid}") | ||||
|         else: | ||||
|             # ensure callee side | ||||
|             assert self._cancel_scope | ||||
|             # TODO: should we have an explicit cancel message | ||||
|             # or is relaying the local `trio.Cancelled` as an | ||||
|             # {'error': trio.Cancelled, cid: "blah"} enough? | ||||
|             # This probably gets into the discussion in | ||||
|             # https://github.com/goodboy/tractor/issues/36 | ||||
|             self._cancel_scope.cancel() | ||||
| 
 | ||||
|     # TODO: do we need a restart api? | ||||
|     # async def restart(self) -> None: | ||||
|     #     # TODO | ||||
|     #     pass | ||||
| 
 | ||||
|     # @asynccontextmanager | ||||
|     # async def open_stream( | ||||
|     #     self, | ||||
|     # ) -> AsyncContextManager: | ||||
|     #     # TODO | ||||
|     #     pass | ||||
|     @asynccontextmanager | ||||
|     async def open_stream( | ||||
|         self, | ||||
|     ) -> MsgStream: | ||||
|         # TODO | ||||
| 
 | ||||
|         actor = current_actor() | ||||
| 
 | ||||
|         # here we create a mem chan that corresponds to the | ||||
|         # far end caller / callee. | ||||
| 
 | ||||
|         # NOTE: in one way streaming this only happens on the | ||||
|         # caller side inside `Actor.send_cmd()` so if you try | ||||
|         # to send a stop from the caller to the callee in the | ||||
|         # single-direction-stream case you'll get a lookup error | ||||
|         # currently. | ||||
|         _, recv_chan = actor.get_memchans( | ||||
|             self.chan.uid, | ||||
|             self.cid | ||||
|         ) | ||||
| 
 | ||||
|         async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan: | ||||
| 
 | ||||
|             if self._portal: | ||||
|                 self._portal._streams.add(rchan) | ||||
| 
 | ||||
|             try: | ||||
|                 yield rchan | ||||
| 
 | ||||
|             finally: | ||||
|                 await self.send_stop() | ||||
|                 if self._portal: | ||||
|                     self._portal._streams.add(rchan) | ||||
| 
 | ||||
|     async def started(self, value: Any) -> None: | ||||
| 
 | ||||
|         if self._portal: | ||||
|             raise RuntimeError( | ||||
|                 f"Caller side context {self} can not call started!") | ||||
| 
 | ||||
|         await self.chan.send({'started': value, 'cid': self.cid}) | ||||
| 
 | ||||
| 
 | ||||
| def stream(func): | ||||
| def stream(func: Callable) -> Callable: | ||||
|     """Mark an async function as a streaming routine with ``@stream``. | ||||
| 
 | ||||
|     """ | ||||
|     # annotate | ||||
|     func._tractor_stream_function = True | ||||
| 
 | ||||
|     sig = inspect.signature(func) | ||||
|     params = sig.parameters | ||||
|     if 'stream' not in params and 'ctx' in params: | ||||
|  | @ -114,147 +345,24 @@ def stream(func): | |||
|     ): | ||||
|         raise TypeError( | ||||
|             "The first argument to the stream function " | ||||
|             f"{func.__name__} must be `ctx: tractor.Context`" | ||||
|             f"{func.__name__} must be `ctx: tractor.Context` " | ||||
|             "(Or ``to_trio`` if using ``asyncio`` in guest mode)." | ||||
|         ) | ||||
|     return func | ||||
| 
 | ||||
| 
 | ||||
| class ReceiveMsgStream(trio.abc.ReceiveChannel): | ||||
|     """A wrapper around a ``trio._channel.MemoryReceiveChannel`` with | ||||
|     special behaviour for signalling stream termination across an | ||||
|     inter-actor ``Channel``. This is the type returned to a local task | ||||
|     which invoked a remote streaming function using `Portal.run()`. | ||||
| 
 | ||||
|     Termination rules: | ||||
|     - if the local task signals stop iteration a cancel signal is | ||||
|       relayed to the remote task indicating to stop streaming | ||||
|     - if the remote task signals the end of a stream, raise a | ||||
|       ``StopAsyncIteration`` to terminate the local ``async for`` | ||||
| def context(func: Callable) -> Callable: | ||||
|     """Mark an async function as a streaming routine with ``@context``. | ||||
| 
 | ||||
|     """ | ||||
|     def __init__( | ||||
|         self, | ||||
|         ctx: Context, | ||||
|         rx_chan: trio.abc.ReceiveChannel, | ||||
|         portal: 'Portal',  # type: ignore # noqa | ||||
|     ) -> None: | ||||
|         self._ctx = ctx | ||||
|         self._rx_chan = rx_chan | ||||
|         self._portal = portal | ||||
|         self._shielded = False | ||||
|     # annotate | ||||
|     func._tractor_context_function = True | ||||
| 
 | ||||
|     # delegate directly to underlying mem channel | ||||
|     def receive_nowait(self): | ||||
|         return self._rx_chan.receive_nowait() | ||||
| 
 | ||||
|     async def receive(self): | ||||
|         try: | ||||
|             msg = await self._rx_chan.receive() | ||||
|             return msg['yield'] | ||||
| 
 | ||||
|         except KeyError: | ||||
|             # internal error should never get here | ||||
|             assert msg.get('cid'), ("Received internal error at portal?") | ||||
| 
 | ||||
|             # TODO: handle 2 cases with 3.10 match syntax | ||||
|             # - 'stop' | ||||
|             # - 'error' | ||||
|             # possibly just handle msg['stop'] here! | ||||
| 
 | ||||
|             # TODO: test that shows stream raising an expected error!!! | ||||
|             if msg.get('error'): | ||||
|                 # raise the error message | ||||
|                 raise unpack_error(msg, self._portal.channel) | ||||
| 
 | ||||
|         except (trio.ClosedResourceError, StopAsyncIteration): | ||||
|             # XXX: this indicates that a `stop` message was | ||||
|             # sent by the far side of the underlying channel. | ||||
|             # Currently this is triggered by calling ``.aclose()`` on | ||||
|             # the send side of the channel inside | ||||
|             # ``Actor._push_result()``, but maybe it should be put here? | ||||
|             # to avoid exposing the internal mem chan closing mechanism? | ||||
|             # in theory we could instead do some flushing of the channel | ||||
|             # if needed to ensure all consumers are complete before | ||||
|             # triggering closure too early? | ||||
| 
 | ||||
|             # Locally, we want to close this stream gracefully, by | ||||
|             # terminating any local consumers tasks deterministically. | ||||
|             # We **don't** want to be closing this send channel and not | ||||
|             # relaying a final value to remaining consumers who may not | ||||
|             # have been scheduled to receive it yet? | ||||
| 
 | ||||
|             # lots of testing to do here | ||||
| 
 | ||||
|             # when the send is closed we assume the stream has | ||||
|             # terminated and signal this local iterator to stop | ||||
|             await self.aclose() | ||||
|             raise StopAsyncIteration | ||||
| 
 | ||||
|         except trio.Cancelled: | ||||
|             # relay cancels to the remote task | ||||
|             await self.aclose() | ||||
|             raise | ||||
| 
 | ||||
|     @contextmanager | ||||
|     def shield( | ||||
|         self | ||||
|     ) -> Iterator['ReceiveMsgStream']:  # noqa | ||||
|         """Shield this stream's underlying channel such that a local consumer task | ||||
|         can be cancelled (and possibly restarted) using ``trio.Cancelled``. | ||||
| 
 | ||||
|         """ | ||||
|         self._shielded = True | ||||
|         yield self | ||||
|         self._shielded = False | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         """Cancel associated remote actor task and local memory channel | ||||
|         on close. | ||||
|         """ | ||||
|         rx_chan = self._rx_chan | ||||
| 
 | ||||
|         if rx_chan._closed: | ||||
|             log.warning(f"{self} is already closed") | ||||
|             return | ||||
| 
 | ||||
|         # stats = rx_chan.statistics() | ||||
|         # if stats.open_receive_channels > 1: | ||||
|         #     # if we've been cloned don't kill the stream | ||||
|         #     log.debug( | ||||
|         #       "there are still consumers running keeping stream alive") | ||||
|         #     return | ||||
| 
 | ||||
|         if self._shielded: | ||||
|             log.warning(f"{self} is shielded, portal channel being kept alive") | ||||
|             return | ||||
| 
 | ||||
|         # close the local mem chan | ||||
|         rx_chan.close() | ||||
| 
 | ||||
|         # cancel surrounding IPC context | ||||
|         await self._ctx.cancel() | ||||
| 
 | ||||
|     # TODO: but make it broadcasting to consumers | ||||
|     # def clone(self): | ||||
|     #     """Clone this receive channel allowing for multi-task | ||||
|     #     consumption from the same channel. | ||||
| 
 | ||||
|     #     """ | ||||
|     #     return ReceiveStream( | ||||
|     #         self._cid, | ||||
|     #         self._rx_chan.clone(), | ||||
|     #         self._portal, | ||||
|     #     ) | ||||
| 
 | ||||
| 
 | ||||
| # class MsgStream(ReceiveMsgStream, trio.abc.Channel): | ||||
| #     """ | ||||
| #     Bidirectional message stream for use within an inter-actor actor | ||||
| #     ``Context```. | ||||
| 
 | ||||
| #     """ | ||||
| #     async def send( | ||||
| #         self, | ||||
| #         data: Any | ||||
| #     ) -> None: | ||||
| #         await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) | ||||
|     sig = inspect.signature(func) | ||||
|     params = sig.parameters | ||||
|     if 'ctx' not in params: | ||||
|         raise TypeError( | ||||
|             "The first argument to the context function " | ||||
|             f"{func.__name__} must be `ctx: tractor.Context`" | ||||
|         ) | ||||
|     return func | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue