Add a way to shield a stream's underlying channel

Add a ``tractor._portal.StreamReceiveChannel.shield_channel()`` context
manager which allows for avoiding the closing of an IPC stream's
underlying channel for the purposes of task re-spawning. Sometimes you
might want to cancel a task consuming a stream but not tear down the IPC
between actors (the default). A common use can might be where the task's
"setup" work might need to be redone but you want to keep the
established portal / channel in tact despite the task restart.

Includes a test.
stream_channel_shield
Tyler Goodlet 2020-12-16 21:42:28 -05:00
parent a510eb0b2b
commit 15ead6b561
2 changed files with 93 additions and 0 deletions

View File

@ -7,6 +7,7 @@ import platform
import trio import trio
import tractor import tractor
from tractor.testing import tractor_test
import pytest import pytest
@ -53,6 +54,7 @@ async def stream_from_single_subactor(stream_func_name):
"""Verify we can spawn a daemon actor and retrieve streamed data. """Verify we can spawn a daemon actor and retrieve streamed data.
""" """
async with tractor.find_actor('streamerd') as portals: async with tractor.find_actor('streamerd') as portals:
if not portals: if not portals:
# only one per host address, spawns an actor if None # only one per host address, spawns an actor if None
async with tractor.open_nursery() as nursery: async with tractor.open_nursery() as nursery:
@ -73,8 +75,10 @@ async def stream_from_single_subactor(stream_func_name):
# it'd sure be nice to have an asyncitertools here... # it'd sure be nice to have an asyncitertools here...
iseq = iter(seq) iseq = iter(seq)
ival = next(iseq) ival = next(iseq)
async for val in stream: async for val in stream:
assert val == ival assert val == ival
try: try:
ival = next(iseq) ival = next(iseq)
except StopIteration: except StopIteration:
@ -83,6 +87,7 @@ async def stream_from_single_subactor(stream_func_name):
await stream.aclose() await stream.aclose()
await trio.sleep(0.3) await trio.sleep(0.3)
try: try:
await stream.__anext__() await stream.__anext__()
except StopAsyncIteration: except StopAsyncIteration:
@ -109,8 +114,11 @@ def test_stream_from_single_subactor(arb_addr, start_method, stream_func):
# this is the first 2 actors, streamer_1 and streamer_2 # this is the first 2 actors, streamer_1 and streamer_2
async def stream_data(seed): async def stream_data(seed):
for i in range(seed): for i in range(seed):
yield i yield i
# trigger scheduler to simulate practical usage # trigger scheduler to simulate practical usage
await trio.sleep(0) await trio.sleep(0)
@ -246,3 +254,68 @@ def test_not_fast_enough_quad(
else: else:
# should be cancelled mid-streaming # should be cancelled mid-streaming
assert results is None assert results is None
@tractor_test
async def test_respawn_consumer_task(
arb_addr,
spawn_backend,
loglevel,
):
"""Verify that ``._portal.StreamReceiveChannel.shield_channel()``
sucessfully protects the underlying IPC channel from being closed
when cancelling and respawning a consumer task.
This also serves to verify that all values from the stream can be
received despite the respawns.
"""
stream = None
async with tractor.open_nursery() as n:
stream = await(await n.run_in_actor(
'streamer',
stream_data,
seed=11,
)).result()
expect = set(range(11))
received = []
# this is the re-spawn task routine
async def consume(task_status=trio.TASK_STATUS_IGNORED):
print('starting consume task..')
nonlocal stream
with trio.CancelScope() as cs:
task_status.started(cs)
# shield stream's underlying channel from cancellation
with stream.shield_channel():
async for v in stream:
print(f'from stream: {v}')
expect.remove(v)
received.append(v)
print('exited consume')
async with trio.open_nursery() as ln:
cs = await ln.start(consume)
while True:
await trio.sleep(0.1)
if received[-1] % 2 == 0:
print('cancelling consume task..')
cs.cancel()
# respawn
cs = await ln.start(consume)
if not expect:
print("all values streamed, BREAKING")
break

View File

@ -7,6 +7,7 @@ import typing
from typing import Tuple, Any, Dict, Optional, Set from typing import Tuple, Any, Dict, Optional, Set
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
from contextlib import contextmanager
import trio import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
@ -59,6 +60,7 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
self._cid = cid self._cid = cid
self._rx_chan = rx_chan self._rx_chan = rx_chan
self._portal = portal self._portal = portal
self._shielded = False
# delegate directly to underlying mem channel # delegate directly to underlying mem channel
def receive_nowait(self): def receive_nowait(self):
@ -83,6 +85,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
"Received internal error at portal?") "Received internal error at portal?")
raise unpack_error(msg, self._portal.channel) raise unpack_error(msg, self._portal.channel)
@contextmanager
def shield_channel(
self
) -> typing.AsyncGenerator['StreamReceiveChannel', None]:
"""Shield this stream's underlying channel such that a local consumer task
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
"""
self._shielded = True
yield self
self._shielded = False
async def aclose(self): async def aclose(self):
"""Cancel associated remote actor task and local memory channel """Cancel associated remote actor task and local memory channel
on close. on close.
@ -90,12 +104,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
if self._rx_chan._closed: if self._rx_chan._closed:
log.warning(f"{self} is already closed") log.warning(f"{self} is already closed")
return return
if self._shielded:
log.warning(f"{self} is shielded, portal channel being kept alive")
return
cid = self._cid cid = self._cid
with trio.move_on_after(0.5) as cs: with trio.move_on_after(0.5) as cs:
cs.shield = True cs.shield = True
log.warning( log.warning(
f"Cancelling stream {cid} to " f"Cancelling stream {cid} to "
f"{self._portal.channel.uid}") f"{self._portal.channel.uid}")
# NOTE: we're telling the far end actor to cancel a task # NOTE: we're telling the far end actor to cancel a task
# corresponding to *this actor*. The far end local channel # corresponding to *this actor*. The far end local channel
# instance is passed to `Actor._cancel_task()` implicitly. # instance is passed to `Actor._cancel_task()` implicitly.