diff --git a/tests/test_streaming.py b/tests/test_streaming.py index c7ea4e0..a64515f 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -262,7 +262,7 @@ async def test_respawn_consumer_task( spawn_backend, loglevel, ): - """Verify that ``._portal.StreamReceiveChannel.shield_channel()`` + """Verify that ``._portal.ReceiveStream.shield()`` sucessfully protects the underlying IPC channel from being closed when cancelling and respawning a consumer task. @@ -292,7 +292,7 @@ async def test_respawn_consumer_task( task_status.started(cs) # shield stream's underlying channel from cancellation - with stream.shield_channel(): + with stream.shield(): async for v in stream: print(f'from stream: {v}') diff --git a/tractor/_portal.py b/tractor/_portal.py index 50036a8..15d86e3 100644 --- a/tractor/_portal.py +++ b/tractor/_portal.py @@ -4,7 +4,7 @@ Portal api import importlib import inspect import typing -from typing import Tuple, Any, Dict, Optional, Set +from typing import Tuple, Any, Dict, Optional, Set, Iterator from functools import partial from dataclasses import dataclass from contextlib import contextmanager @@ -38,7 +38,7 @@ async def maybe_open_nursery( yield nursery -class StreamReceiveChannel(trio.abc.ReceiveChannel): +class ReceiveStream(trio.abc.ReceiveChannel): """A wrapper around a ``trio._channel.MemoryReceiveChannel`` with special behaviour for signalling stream termination across an inter-actor ``Channel``. This is the type returned to a local task @@ -86,9 +86,9 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel): raise unpack_error(msg, self._portal.channel) @contextmanager - def shield_channel( + def shield( self - ) -> typing.AsyncGenerator['StreamReceiveChannel', None]: + ) -> Iterator['ReceiveStream']: # noqa """Shield this stream's underlying channel such that a local consumer task can be cancelled (and possibly restarted) using ``trio.Cancelled``. @@ -156,7 +156,7 @@ class Portal: self._expect_result: Optional[ Tuple[str, Any, str, Dict[str, Any]] ] = None - self._streams: Set[StreamReceiveChannel] = set() + self._streams: Set[ReceiveStream] = set() self.actor = current_actor() async def _submit( @@ -219,7 +219,7 @@ class Portal: # to make async-generators the fundamental IPC API over channels! # (think `yield from`, `gen.send()`, and functional reactive stuff) if resptype == 'yield': # stream response - rchan = StreamReceiveChannel(cid, recv_chan, self) + rchan = ReceiveStream(cid, recv_chan, self) self._streams.add(rchan) return rchan @@ -322,7 +322,7 @@ class LocalPortal: A compatibility shim for normal portals but for invoking functions using an in process actor instance. """ - actor: 'Actor' # type: ignore + actor: 'Actor' # type: ignore # noqa channel: Channel async def run(self, ns: str, func_name: str, **kwargs) -> Any: