Add initial bi-directional streaming
This mostly adds the api described in https://github.com/goodboy/tractor/issues/53#issuecomment-806258798 The first draft summary: - formalize bidir steaming using the `trio.Channel` style interface which we derive as a `MsgStream` type. - add `Portal.open_context()` which provides a `trio.Nursery.start()` remote task invocation style for setting up and tearing down tasks contexts in remote actors. - add a distinct `'started'` message to the ipc protocol to facilitate `Context.start()` with a first return value. - for our `ReceiveMsgStream` type, don't cancel the remote task in `.aclose()`; this is now done explicitly by the surrounding `Context` usage: `Context.cancel()`. - streams in either direction still use a `'yield'` message keeping the proto mostly symmetric without having to worry about which side is the caller / portal opener. - subtlety: only allow sending a `'stop'` message during a 2-way streaming context from `ReceiveStream.aclose()`, detailed comment with explanation is included. Relates to #53transport_hardening
							parent
							
								
									76f07898d9
								
							
						
					
					
						commit
						2870828c34
					
				|  | @ -14,6 +14,7 @@ from types import ModuleType | ||||||
| import sys | import sys | ||||||
| import os | import os | ||||||
| from contextlib import ExitStack | from contextlib import ExitStack | ||||||
|  | import warnings | ||||||
| 
 | 
 | ||||||
| import trio  # type: ignore | import trio  # type: ignore | ||||||
| from trio_typing import TaskStatus | from trio_typing import TaskStatus | ||||||
|  | @ -58,13 +59,37 @@ async def _invoke( | ||||||
|     treat_as_gen = False |     treat_as_gen = False | ||||||
|     cs = None |     cs = None | ||||||
|     cancel_scope = trio.CancelScope() |     cancel_scope = trio.CancelScope() | ||||||
|     ctx = Context(chan, cid, cancel_scope) |     ctx = Context(chan, cid, _cancel_scope=cancel_scope) | ||||||
|  |     context = False | ||||||
| 
 | 
 | ||||||
|     if getattr(func, '_tractor_stream_function', False): |     if getattr(func, '_tractor_stream_function', False): | ||||||
|         # handle decorated ``@tractor.stream`` async functions |         # handle decorated ``@tractor.stream`` async functions | ||||||
|  |         sig = inspect.signature(func) | ||||||
|  |         params = sig.parameters | ||||||
|  | 
 | ||||||
|  |         # compat with old api | ||||||
|         kwargs['ctx'] = ctx |         kwargs['ctx'] = ctx | ||||||
|  | 
 | ||||||
|  |         if 'ctx' in params: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "`@tractor.stream decorated funcs should now declare " | ||||||
|  |                 "a `stream`  arg, `ctx` is now designated for use with " | ||||||
|  |                 "@tractor.context", | ||||||
|  |                 DeprecationWarning, | ||||||
|  |                 stacklevel=2, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         elif 'stream' in params: | ||||||
|  |             assert 'stream' in params | ||||||
|  |             kwargs['stream'] = ctx | ||||||
|  | 
 | ||||||
|         treat_as_gen = True |         treat_as_gen = True | ||||||
| 
 | 
 | ||||||
|  |     elif getattr(func, '_tractor_context_function', False): | ||||||
|  |         # handle decorated ``@tractor.context`` async function | ||||||
|  |         kwargs['ctx'] = ctx | ||||||
|  |         context = True | ||||||
|  | 
 | ||||||
|     # errors raised inside this block are propgated back to caller |     # errors raised inside this block are propgated back to caller | ||||||
|     try: |     try: | ||||||
|         if not ( |         if not ( | ||||||
|  | @ -102,8 +127,9 @@ async def _invoke( | ||||||
|             # `StopAsyncIteration` system here for returning a final |             # `StopAsyncIteration` system here for returning a final | ||||||
|             # value if desired |             # value if desired | ||||||
|             await chan.send({'stop': True, 'cid': cid}) |             await chan.send({'stop': True, 'cid': cid}) | ||||||
|         else: | 
 | ||||||
|             if treat_as_gen: |         # one way @stream func that gets treated like an async gen | ||||||
|  |         elif treat_as_gen: | ||||||
|             await chan.send({'functype': 'asyncgen', 'cid': cid}) |             await chan.send({'functype': 'asyncgen', 'cid': cid}) | ||||||
|             # XXX: the async-func may spawn further tasks which push |             # XXX: the async-func may spawn further tasks which push | ||||||
|             # back values like an async-generator would but must |             # back values like an async-generator would but must | ||||||
|  | @ -112,10 +138,24 @@ async def _invoke( | ||||||
|             with cancel_scope as cs: |             with cancel_scope as cs: | ||||||
|                 task_status.started(cs) |                 task_status.started(cs) | ||||||
|                 await coro |                 await coro | ||||||
|  | 
 | ||||||
|             if not cs.cancelled_caught: |             if not cs.cancelled_caught: | ||||||
|                 # task was not cancelled so we can instruct the |                 # task was not cancelled so we can instruct the | ||||||
|                 # far end async gen to tear down |                 # far end async gen to tear down | ||||||
|                 await chan.send({'stop': True, 'cid': cid}) |                 await chan.send({'stop': True, 'cid': cid}) | ||||||
|  | 
 | ||||||
|  |         elif context: | ||||||
|  |             # context func with support for bi-dir streaming | ||||||
|  |             await chan.send({'functype': 'context', 'cid': cid}) | ||||||
|  | 
 | ||||||
|  |             with cancel_scope as cs: | ||||||
|  |                 task_status.started(cs) | ||||||
|  |                 await chan.send({'return': await coro, 'cid': cid}) | ||||||
|  | 
 | ||||||
|  |             # if cs.cancelled_caught: | ||||||
|  |             #     # task was cancelled so relay to the cancel to caller | ||||||
|  |             #     await chan.send({'return': await coro, 'cid': cid}) | ||||||
|  | 
 | ||||||
|         else: |         else: | ||||||
|             # regular async function |             # regular async function | ||||||
|             await chan.send({'functype': 'asyncfunc', 'cid': cid}) |             await chan.send({'functype': 'asyncfunc', 'cid': cid}) | ||||||
|  | @ -417,10 +457,10 @@ class Actor: | ||||||
|         send_chan, recv_chan = self._cids2qs[(actorid, cid)] |         send_chan, recv_chan = self._cids2qs[(actorid, cid)] | ||||||
|         assert send_chan.cid == cid  # type: ignore |         assert send_chan.cid == cid  # type: ignore | ||||||
| 
 | 
 | ||||||
|         if 'stop' in msg: |         # if 'stop' in msg: | ||||||
|             log.debug(f"{send_chan} was terminated at remote end") |         #     log.debug(f"{send_chan} was terminated at remote end") | ||||||
|             # indicate to consumer that far end has stopped |         #     # indicate to consumer that far end has stopped | ||||||
|             return await send_chan.aclose() |         #     return await send_chan.aclose() | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             log.debug(f"Delivering {msg} from {actorid} to caller {cid}") |             log.debug(f"Delivering {msg} from {actorid} to caller {cid}") | ||||||
|  | @ -428,6 +468,12 @@ class Actor: | ||||||
|             await send_chan.send(msg) |             await send_chan.send(msg) | ||||||
| 
 | 
 | ||||||
|         except trio.BrokenResourceError: |         except trio.BrokenResourceError: | ||||||
|  |             # TODO: what is the right way to handle the case where the | ||||||
|  |             # local task has already sent a 'stop' / StopAsyncInteration | ||||||
|  |             # to the other side but and possibly has closed the local | ||||||
|  |             # feeder mem chan? Do we wait for some kind of ack or just | ||||||
|  |             # let this fail silently and bubble up (currently)? | ||||||
|  | 
 | ||||||
|             # XXX: local consumer has closed their side |             # XXX: local consumer has closed their side | ||||||
|             # so cancel the far end streaming task |             # so cancel the far end streaming task | ||||||
|             log.warning(f"{send_chan} consumer is already closed") |             log.warning(f"{send_chan} consumer is already closed") | ||||||
|  | @ -508,6 +554,7 @@ class Actor: | ||||||
|                     if cid: |                     if cid: | ||||||
|                         # deliver response to local caller/waiter |                         # deliver response to local caller/waiter | ||||||
|                         await self._push_result(chan, cid, msg) |                         await self._push_result(chan, cid, msg) | ||||||
|  | 
 | ||||||
|                         log.debug( |                         log.debug( | ||||||
|                             f"Waiting on next msg for {chan} from {chan.uid}") |                             f"Waiting on next msg for {chan} from {chan.uid}") | ||||||
|                         continue |                         continue | ||||||
|  |  | ||||||
|  | @ -312,11 +312,20 @@ class Portal: | ||||||
| 
 | 
 | ||||||
|         ctx = Context(self.channel, cid, _portal=self) |         ctx = Context(self.channel, cid, _portal=self) | ||||||
|         try: |         try: | ||||||
|             async with ReceiveMsgStream(ctx, recv_chan, self) as rchan: |             # deliver receive only stream | ||||||
|  |             async with ReceiveMsgStream(ctx, recv_chan) as rchan: | ||||||
|                 self._streams.add(rchan) |                 self._streams.add(rchan) | ||||||
|                 yield rchan |                 yield rchan | ||||||
|  | 
 | ||||||
|         finally: |         finally: | ||||||
|  | 
 | ||||||
|             # cancel the far end task on consumer close |             # cancel the far end task on consumer close | ||||||
|  |             # NOTE: this is a special case since we assume that if using | ||||||
|  |             # this ``.open_fream_from()`` api, the stream is one a one | ||||||
|  |             # time use and we couple the far end tasks's lifetime to | ||||||
|  |             # the consumer's scope; we don't ever send a `'stop'` | ||||||
|  |             # message right now since there shouldn't be a reason to | ||||||
|  |             # stop and restart the stream, right? | ||||||
|             try: |             try: | ||||||
|                 await ctx.cancel() |                 await ctx.cancel() | ||||||
|             except trio.ClosedResourceError: |             except trio.ClosedResourceError: | ||||||
|  | @ -326,16 +335,55 @@ class Portal: | ||||||
| 
 | 
 | ||||||
|             self._streams.remove(rchan) |             self._streams.remove(rchan) | ||||||
| 
 | 
 | ||||||
|     # @asynccontextmanager |     @asynccontextmanager | ||||||
|     # async def open_context( |     async def open_context( | ||||||
|     #     self, |         self, | ||||||
|     #     func: Callable, |         func: Callable, | ||||||
|     #     **kwargs, |         **kwargs, | ||||||
|     # ) -> Context: |     ) -> Context: | ||||||
|     #     # TODO |         """Open an inter-actor task context. | ||||||
|     #     elif resptype == 'context':  # context manager style setup/teardown | 
 | ||||||
|     #         # TODO likely not here though |         This is a synchronous API which allows for deterministic | ||||||
|     #         raise NotImplementedError |         setup/teardown of a remote task. The yielded ``Context`` further | ||||||
|  |         allows for opening bidirectional streams - see | ||||||
|  |         ``Context.open_stream()``. | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  |         # conduct target func method structural checks | ||||||
|  |         if not inspect.iscoroutinefunction(func) and ( | ||||||
|  |             getattr(func, '_tractor_contex_function', False) | ||||||
|  |         ): | ||||||
|  |             raise TypeError( | ||||||
|  |                 f'{func} must be an async generator function!') | ||||||
|  | 
 | ||||||
|  |         fn_mod_path, fn_name = func_deats(func) | ||||||
|  | 
 | ||||||
|  |         cid, recv_chan, functype, first_msg = await self._submit( | ||||||
|  |             fn_mod_path, fn_name, kwargs) | ||||||
|  | 
 | ||||||
|  |         assert functype == 'context' | ||||||
|  | 
 | ||||||
|  |         msg = await recv_chan.receive() | ||||||
|  |         try: | ||||||
|  |             # the "first" value here is delivered by the callee's | ||||||
|  |             # ``Context.started()`` call. | ||||||
|  |             first = msg['started'] | ||||||
|  | 
 | ||||||
|  |         except KeyError: | ||||||
|  |             assert msg.get('cid'), ("Received internal error at context?") | ||||||
|  | 
 | ||||||
|  |             if msg.get('error'): | ||||||
|  |                 # raise the error message | ||||||
|  |                 raise unpack_error(msg, self.channel) | ||||||
|  |             else: | ||||||
|  |                 raise | ||||||
|  |         try: | ||||||
|  |             ctx = Context(self.channel, cid, _portal=self) | ||||||
|  |             yield ctx, first | ||||||
|  | 
 | ||||||
|  |         finally: | ||||||
|  |             await recv_chan.aclose() | ||||||
|  |             await ctx.cancel() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
|  |  | ||||||
|  | @ -1,19 +1,195 @@ | ||||||
| import inspect | import inspect | ||||||
| from contextlib import contextmanager  # , asynccontextmanager | from contextlib import contextmanager, asynccontextmanager | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import Any, Iterator, Optional | from typing import Any, Iterator, Optional, Callable | ||||||
| import warnings | import warnings | ||||||
| 
 | 
 | ||||||
| import trio | import trio | ||||||
| 
 | 
 | ||||||
| from ._ipc import Channel | from ._ipc import Channel | ||||||
| from ._exceptions import unpack_error | from ._exceptions import unpack_error | ||||||
|  | from ._state import current_actor | ||||||
| from .log import get_logger | from .log import get_logger | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| log = get_logger(__name__) | log = get_logger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | # TODO: generic typing like trio's receive channel | ||||||
|  | # but with msgspec messages? | ||||||
|  | # class ReceiveChannel(AsyncResource, Generic[ReceiveType]): | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ReceiveMsgStream(trio.abc.ReceiveChannel): | ||||||
|  |     """A wrapper around a ``trio._channel.MemoryReceiveChannel`` with | ||||||
|  |     special behaviour for signalling stream termination across an | ||||||
|  |     inter-actor ``Channel``. This is the type returned to a local task | ||||||
|  |     which invoked a remote streaming function using `Portal.run()`. | ||||||
|  | 
 | ||||||
|  |     Termination rules: | ||||||
|  |     - if the local task signals stop iteration a cancel signal is | ||||||
|  |       relayed to the remote task indicating to stop streaming | ||||||
|  |     - if the remote task signals the end of a stream, raise a | ||||||
|  |       ``StopAsyncIteration`` to terminate the local ``async for`` | ||||||
|  | 
 | ||||||
|  |     """ | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         ctx: 'Context',  # typing: ignore # noqa | ||||||
|  |         rx_chan: trio.abc.ReceiveChannel, | ||||||
|  |     ) -> None: | ||||||
|  |         self._ctx = ctx | ||||||
|  |         self._rx_chan = rx_chan | ||||||
|  |         self._shielded = False | ||||||
|  | 
 | ||||||
|  |     # delegate directly to underlying mem channel | ||||||
|  |     def receive_nowait(self): | ||||||
|  |         return self._rx_chan.receive_nowait() | ||||||
|  | 
 | ||||||
|  |     async def receive(self): | ||||||
|  |         try: | ||||||
|  |             msg = await self._rx_chan.receive() | ||||||
|  |             return msg['yield'] | ||||||
|  | 
 | ||||||
|  |         except KeyError: | ||||||
|  |             # internal error should never get here | ||||||
|  |             assert msg.get('cid'), ("Received internal error at portal?") | ||||||
|  | 
 | ||||||
|  |             # TODO: handle 2 cases with 3.10 match syntax | ||||||
|  |             # - 'stop' | ||||||
|  |             # - 'error' | ||||||
|  |             # possibly just handle msg['stop'] here! | ||||||
|  | 
 | ||||||
|  |             if msg.get('stop'): | ||||||
|  |                 log.debug(f"{self} was stopped at remote end") | ||||||
|  |                 # when the send is closed we assume the stream has | ||||||
|  |                 # terminated and signal this local iterator to stop | ||||||
|  |                 await self.aclose() | ||||||
|  |                 raise trio.EndOfChannel | ||||||
|  | 
 | ||||||
|  |             # TODO: test that shows stream raising an expected error!!! | ||||||
|  |             elif msg.get('error'): | ||||||
|  |                 # raise the error message | ||||||
|  |                 raise unpack_error(msg, self._ctx.chan) | ||||||
|  | 
 | ||||||
|  |             else: | ||||||
|  |                 raise | ||||||
|  | 
 | ||||||
|  |         except (trio.ClosedResourceError, StopAsyncIteration): | ||||||
|  |             # XXX: this indicates that a `stop` message was | ||||||
|  |             # sent by the far side of the underlying channel. | ||||||
|  |             # Currently this is triggered by calling ``.aclose()`` on | ||||||
|  |             # the send side of the channel inside | ||||||
|  |             # ``Actor._push_result()``, but maybe it should be put here? | ||||||
|  |             # to avoid exposing the internal mem chan closing mechanism? | ||||||
|  |             # in theory we could instead do some flushing of the channel | ||||||
|  |             # if needed to ensure all consumers are complete before | ||||||
|  |             # triggering closure too early? | ||||||
|  | 
 | ||||||
|  |             # Locally, we want to close this stream gracefully, by | ||||||
|  |             # terminating any local consumers tasks deterministically. | ||||||
|  |             # We **don't** want to be closing this send channel and not | ||||||
|  |             # relaying a final value to remaining consumers who may not | ||||||
|  |             # have been scheduled to receive it yet? | ||||||
|  | 
 | ||||||
|  |             # lots of testing to do here | ||||||
|  | 
 | ||||||
|  |             # when the send is closed we assume the stream has | ||||||
|  |             # terminated and signal this local iterator to stop | ||||||
|  |             await self.aclose() | ||||||
|  |             # await self._ctx.send_stop() | ||||||
|  |             raise StopAsyncIteration | ||||||
|  | 
 | ||||||
|  |         except trio.Cancelled: | ||||||
|  |             # relay cancels to the remote task | ||||||
|  |             await self.aclose() | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|  |     @contextmanager | ||||||
|  |     def shield( | ||||||
|  |         self | ||||||
|  |     ) -> Iterator['ReceiveMsgStream']:  # noqa | ||||||
|  |         """Shield this stream's underlying channel such that a local consumer task | ||||||
|  |         can be cancelled (and possibly restarted) using ``trio.Cancelled``. | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  |         self._shielded = True | ||||||
|  |         yield self | ||||||
|  |         self._shielded = False | ||||||
|  | 
 | ||||||
|  |     async def aclose(self): | ||||||
|  |         """Cancel associated remote actor task and local memory channel | ||||||
|  |         on close. | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  |         # TODO: proper adherance to trio's `.aclose()` semantics: | ||||||
|  |         # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose | ||||||
|  |         rx_chan = self._rx_chan | ||||||
|  | 
 | ||||||
|  |         if rx_chan._closed: | ||||||
|  |             log.warning(f"{self} is already closed") | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         # TODO: broadcasting to multiple consumers | ||||||
|  |         # stats = rx_chan.statistics() | ||||||
|  |         # if stats.open_receive_channels > 1: | ||||||
|  |         #     # if we've been cloned don't kill the stream | ||||||
|  |         #     log.debug( | ||||||
|  |         #       "there are still consumers running keeping stream alive") | ||||||
|  |         #     return | ||||||
|  | 
 | ||||||
|  |         if self._shielded: | ||||||
|  |             log.warning(f"{self} is shielded, portal channel being kept alive") | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         # NOTE: this is super subtle IPC messaging stuff: | ||||||
|  |         # Relay stop iteration to far end **iff** we're | ||||||
|  |         # in bidirectional mode. If we're only streaming | ||||||
|  |         # *from* one side then that side **won't** have an | ||||||
|  |         # entry in `Actor._cids2qs` (maybe it should though?). | ||||||
|  |         # So any `yield` or `stop` msgs sent from the caller side | ||||||
|  |         # will cause key errors on the callee side since there is | ||||||
|  |         # no entry for a local feeder mem chan since the callee task | ||||||
|  |         # isn't expecting messages to be sent by the caller. | ||||||
|  |         # Thus, we must check that this context DOES NOT | ||||||
|  |         # have a portal reference to ensure this is indeed the callee | ||||||
|  |         # side and can relay a 'stop'. In the bidirectional case, | ||||||
|  |         # `Context.open_stream()` will create the `Actor._cids2qs` | ||||||
|  |         # entry from a call to `Actor.get_memchans()`. | ||||||
|  |         if not self._ctx._portal: | ||||||
|  |             # only for 2 way streams can we can send | ||||||
|  |             # stop from the caller side | ||||||
|  |             await self._ctx.send_stop() | ||||||
|  | 
 | ||||||
|  |         # close the local mem chan | ||||||
|  |         rx_chan.close() | ||||||
|  | 
 | ||||||
|  |     # TODO: but make it broadcasting to consumers | ||||||
|  |     # def clone(self): | ||||||
|  |     #     """Clone this receive channel allowing for multi-task | ||||||
|  |     #     consumption from the same channel. | ||||||
|  | 
 | ||||||
|  |     #     """ | ||||||
|  |     #     return ReceiveStream( | ||||||
|  |     #         self._cid, | ||||||
|  |     #         self._rx_chan.clone(), | ||||||
|  |     #         self._portal, | ||||||
|  |     #     ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MsgStream(ReceiveMsgStream, trio.abc.Channel): | ||||||
|  |     """ | ||||||
|  |     Bidirectional message stream for use within an inter-actor actor | ||||||
|  |     ``Context```. | ||||||
|  | 
 | ||||||
|  |     """ | ||||||
|  |     async def send( | ||||||
|  |         self, | ||||||
|  |         data: Any | ||||||
|  |     ) -> None: | ||||||
|  |         await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @dataclass(frozen=True) | @dataclass(frozen=True) | ||||||
| class Context: | class Context: | ||||||
|     """An IAC (inter-actor communication) context. |     """An IAC (inter-actor communication) context. | ||||||
|  | @ -31,6 +207,10 @@ class Context: | ||||||
|     chan: Channel |     chan: Channel | ||||||
|     cid: str |     cid: str | ||||||
| 
 | 
 | ||||||
|  |     # TODO: should we have seperate types for caller vs. callee | ||||||
|  |     # side contexts? The caller always opens a portal whereas the callee | ||||||
|  |     # is always responding back through a context-stream | ||||||
|  | 
 | ||||||
|     # only set on the caller side |     # only set on the caller side | ||||||
|     _portal: Optional['Portal'] = None    # type: ignore # noqa |     _portal: Optional['Portal'] = None    # type: ignore # noqa | ||||||
| 
 | 
 | ||||||
|  | @ -57,8 +237,11 @@ class Context: | ||||||
|         timeout quickly to sidestep 2-generals... |         timeout quickly to sidestep 2-generals... | ||||||
| 
 | 
 | ||||||
|         """ |         """ | ||||||
|         assert self._portal, ( |         if self._portal:  # caller side: | ||||||
|             "No portal found, this is likely a callee side context") |             if not self._portal: | ||||||
|  |                 raise RuntimeError( | ||||||
|  |                     "No portal found, this is likely a callee side context" | ||||||
|  |                 ) | ||||||
| 
 | 
 | ||||||
|             cid = self.cid |             cid = self.cid | ||||||
|             with trio.move_on_after(0.5) as cs: |             with trio.move_on_after(0.5) as cs: | ||||||
|  | @ -76,27 +259,75 @@ class Context: | ||||||
|                 # XXX: there's no way to know if the remote task was indeed |                 # XXX: there's no way to know if the remote task was indeed | ||||||
|                 # cancelled in the case where the connection is broken or |                 # cancelled in the case where the connection is broken or | ||||||
|                 # some other network error occurred. |                 # some other network error occurred. | ||||||
|             if not self._portal.channel.connected(): |                 # if not self._portal.channel.connected(): | ||||||
|  |                 if not self.chan.connected(): | ||||||
|                     log.warning( |                     log.warning( | ||||||
|                         "May have failed to cancel remote task " |                         "May have failed to cancel remote task " | ||||||
|                         f"{cid} for {self._portal.channel.uid}") |                         f"{cid} for {self._portal.channel.uid}") | ||||||
|  |         else: | ||||||
|  |             # ensure callee side | ||||||
|  |             assert self._cancel_scope | ||||||
|  |             # TODO: should we have an explicit cancel message | ||||||
|  |             # or is relaying the local `trio.Cancelled` as an | ||||||
|  |             # {'error': trio.Cancelled, cid: "blah"} enough? | ||||||
|  |             # This probably gets into the discussion in | ||||||
|  |             # https://github.com/goodboy/tractor/issues/36 | ||||||
|  |             self._cancel_scope.cancel() | ||||||
| 
 | 
 | ||||||
|  |     # TODO: do we need a restart api? | ||||||
|     # async def restart(self) -> None: |     # async def restart(self) -> None: | ||||||
|     #     # TODO |  | ||||||
|     #     pass |     #     pass | ||||||
| 
 | 
 | ||||||
|     # @asynccontextmanager |     @asynccontextmanager | ||||||
|     # async def open_stream( |     async def open_stream( | ||||||
|     #     self, |         self, | ||||||
|     # ) -> AsyncContextManager: |     ) -> MsgStream: | ||||||
|     #     # TODO |         # TODO | ||||||
|     #     pass | 
 | ||||||
|  |         actor = current_actor() | ||||||
|  | 
 | ||||||
|  |         # here we create a mem chan that corresponds to the | ||||||
|  |         # far end caller / callee. | ||||||
|  | 
 | ||||||
|  |         # NOTE: in one way streaming this only happens on the | ||||||
|  |         # caller side inside `Actor.send_cmd()` so if you try | ||||||
|  |         # to send a stop from the caller to the callee in the | ||||||
|  |         # single-direction-stream case you'll get a lookup error | ||||||
|  |         # currently. | ||||||
|  |         _, recv_chan = actor.get_memchans( | ||||||
|  |             self.chan.uid, | ||||||
|  |             self.cid | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         async with MsgStream(ctx=self, rx_chan=recv_chan) as rchan: | ||||||
|  | 
 | ||||||
|  |             if self._portal: | ||||||
|  |                 self._portal._streams.add(rchan) | ||||||
|  | 
 | ||||||
|  |             try: | ||||||
|  |                 yield rchan | ||||||
|  | 
 | ||||||
|  |             finally: | ||||||
|  |                 await self.send_stop() | ||||||
|  |                 if self._portal: | ||||||
|  |                     self._portal._streams.add(rchan) | ||||||
|  | 
 | ||||||
|  |     async def started(self, value: Any) -> None: | ||||||
|  | 
 | ||||||
|  |         if self._portal: | ||||||
|  |             raise RuntimeError( | ||||||
|  |                 f"Caller side context {self} can not call started!") | ||||||
|  | 
 | ||||||
|  |         await self.chan.send({'started': value, 'cid': self.cid}) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def stream(func): | def stream(func: Callable) -> Callable: | ||||||
|     """Mark an async function as a streaming routine with ``@stream``. |     """Mark an async function as a streaming routine with ``@stream``. | ||||||
|  | 
 | ||||||
|     """ |     """ | ||||||
|  |     # annotate | ||||||
|     func._tractor_stream_function = True |     func._tractor_stream_function = True | ||||||
|  | 
 | ||||||
|     sig = inspect.signature(func) |     sig = inspect.signature(func) | ||||||
|     params = sig.parameters |     params = sig.parameters | ||||||
|     if 'stream' not in params and 'ctx' in params: |     if 'stream' not in params and 'ctx' in params: | ||||||
|  | @ -114,147 +345,24 @@ def stream(func): | ||||||
|     ): |     ): | ||||||
|         raise TypeError( |         raise TypeError( | ||||||
|             "The first argument to the stream function " |             "The first argument to the stream function " | ||||||
|             f"{func.__name__} must be `ctx: tractor.Context`" |             f"{func.__name__} must be `ctx: tractor.Context` " | ||||||
|  |             "(Or ``to_trio`` if using ``asyncio`` in guest mode)." | ||||||
|         ) |         ) | ||||||
|     return func |     return func | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ReceiveMsgStream(trio.abc.ReceiveChannel): | def context(func: Callable) -> Callable: | ||||||
|     """A wrapper around a ``trio._channel.MemoryReceiveChannel`` with |     """Mark an async function as a streaming routine with ``@context``. | ||||||
|     special behaviour for signalling stream termination across an |  | ||||||
|     inter-actor ``Channel``. This is the type returned to a local task |  | ||||||
|     which invoked a remote streaming function using `Portal.run()`. |  | ||||||
| 
 |  | ||||||
|     Termination rules: |  | ||||||
|     - if the local task signals stop iteration a cancel signal is |  | ||||||
|       relayed to the remote task indicating to stop streaming |  | ||||||
|     - if the remote task signals the end of a stream, raise a |  | ||||||
|       ``StopAsyncIteration`` to terminate the local ``async for`` |  | ||||||
| 
 | 
 | ||||||
|     """ |     """ | ||||||
|     def __init__( |     # annotate | ||||||
|         self, |     func._tractor_context_function = True | ||||||
|         ctx: Context, |  | ||||||
|         rx_chan: trio.abc.ReceiveChannel, |  | ||||||
|         portal: 'Portal',  # type: ignore # noqa |  | ||||||
|     ) -> None: |  | ||||||
|         self._ctx = ctx |  | ||||||
|         self._rx_chan = rx_chan |  | ||||||
|         self._portal = portal |  | ||||||
|         self._shielded = False |  | ||||||
| 
 | 
 | ||||||
|     # delegate directly to underlying mem channel |     sig = inspect.signature(func) | ||||||
|     def receive_nowait(self): |     params = sig.parameters | ||||||
|         return self._rx_chan.receive_nowait() |     if 'ctx' not in params: | ||||||
| 
 |         raise TypeError( | ||||||
|     async def receive(self): |             "The first argument to the context function " | ||||||
|         try: |             f"{func.__name__} must be `ctx: tractor.Context`" | ||||||
|             msg = await self._rx_chan.receive() |         ) | ||||||
|             return msg['yield'] |     return func | ||||||
| 
 |  | ||||||
|         except KeyError: |  | ||||||
|             # internal error should never get here |  | ||||||
|             assert msg.get('cid'), ("Received internal error at portal?") |  | ||||||
| 
 |  | ||||||
|             # TODO: handle 2 cases with 3.10 match syntax |  | ||||||
|             # - 'stop' |  | ||||||
|             # - 'error' |  | ||||||
|             # possibly just handle msg['stop'] here! |  | ||||||
| 
 |  | ||||||
|             # TODO: test that shows stream raising an expected error!!! |  | ||||||
|             if msg.get('error'): |  | ||||||
|                 # raise the error message |  | ||||||
|                 raise unpack_error(msg, self._portal.channel) |  | ||||||
| 
 |  | ||||||
|         except (trio.ClosedResourceError, StopAsyncIteration): |  | ||||||
|             # XXX: this indicates that a `stop` message was |  | ||||||
|             # sent by the far side of the underlying channel. |  | ||||||
|             # Currently this is triggered by calling ``.aclose()`` on |  | ||||||
|             # the send side of the channel inside |  | ||||||
|             # ``Actor._push_result()``, but maybe it should be put here? |  | ||||||
|             # to avoid exposing the internal mem chan closing mechanism? |  | ||||||
|             # in theory we could instead do some flushing of the channel |  | ||||||
|             # if needed to ensure all consumers are complete before |  | ||||||
|             # triggering closure too early? |  | ||||||
| 
 |  | ||||||
|             # Locally, we want to close this stream gracefully, by |  | ||||||
|             # terminating any local consumers tasks deterministically. |  | ||||||
|             # We **don't** want to be closing this send channel and not |  | ||||||
|             # relaying a final value to remaining consumers who may not |  | ||||||
|             # have been scheduled to receive it yet? |  | ||||||
| 
 |  | ||||||
|             # lots of testing to do here |  | ||||||
| 
 |  | ||||||
|             # when the send is closed we assume the stream has |  | ||||||
|             # terminated and signal this local iterator to stop |  | ||||||
|             await self.aclose() |  | ||||||
|             raise StopAsyncIteration |  | ||||||
| 
 |  | ||||||
|         except trio.Cancelled: |  | ||||||
|             # relay cancels to the remote task |  | ||||||
|             await self.aclose() |  | ||||||
|             raise |  | ||||||
| 
 |  | ||||||
|     @contextmanager |  | ||||||
|     def shield( |  | ||||||
|         self |  | ||||||
|     ) -> Iterator['ReceiveMsgStream']:  # noqa |  | ||||||
|         """Shield this stream's underlying channel such that a local consumer task |  | ||||||
|         can be cancelled (and possibly restarted) using ``trio.Cancelled``. |  | ||||||
| 
 |  | ||||||
|         """ |  | ||||||
|         self._shielded = True |  | ||||||
|         yield self |  | ||||||
|         self._shielded = False |  | ||||||
| 
 |  | ||||||
|     async def aclose(self): |  | ||||||
|         """Cancel associated remote actor task and local memory channel |  | ||||||
|         on close. |  | ||||||
|         """ |  | ||||||
|         rx_chan = self._rx_chan |  | ||||||
| 
 |  | ||||||
|         if rx_chan._closed: |  | ||||||
|             log.warning(f"{self} is already closed") |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         # stats = rx_chan.statistics() |  | ||||||
|         # if stats.open_receive_channels > 1: |  | ||||||
|         #     # if we've been cloned don't kill the stream |  | ||||||
|         #     log.debug( |  | ||||||
|         #       "there are still consumers running keeping stream alive") |  | ||||||
|         #     return |  | ||||||
| 
 |  | ||||||
|         if self._shielded: |  | ||||||
|             log.warning(f"{self} is shielded, portal channel being kept alive") |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         # close the local mem chan |  | ||||||
|         rx_chan.close() |  | ||||||
| 
 |  | ||||||
|         # cancel surrounding IPC context |  | ||||||
|         await self._ctx.cancel() |  | ||||||
| 
 |  | ||||||
|     # TODO: but make it broadcasting to consumers |  | ||||||
|     # def clone(self): |  | ||||||
|     #     """Clone this receive channel allowing for multi-task |  | ||||||
|     #     consumption from the same channel. |  | ||||||
| 
 |  | ||||||
|     #     """ |  | ||||||
|     #     return ReceiveStream( |  | ||||||
|     #         self._cid, |  | ||||||
|     #         self._rx_chan.clone(), |  | ||||||
|     #         self._portal, |  | ||||||
|     #     ) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # class MsgStream(ReceiveMsgStream, trio.abc.Channel): |  | ||||||
| #     """ |  | ||||||
| #     Bidirectional message stream for use within an inter-actor actor |  | ||||||
| #     ``Context```. |  | ||||||
| 
 |  | ||||||
| #     """ |  | ||||||
| #     async def send( |  | ||||||
| #         self, |  | ||||||
| #         data: Any |  | ||||||
| #     ) -> None: |  | ||||||
| #         await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) |  | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue