Add a way to shield a stream's underlying channel

Add a ``tractor._portal.StreamReceiveChannel.shield_channel()`` context
manager which allows for avoiding the closing of an IPC stream's
underlying channel for the purposes of task re-spawning. Sometimes you
might want to cancel a task consuming a stream but not tear down the IPC
between actors (the default). A common use can might be where the task's
"setup" work might need to be redone but you want to keep the
established portal / channel in tact despite the task restart.

Includes a test.
stream_channel_shield
Tyler Goodlet 2020-12-16 21:42:28 -05:00
parent a510eb0b2b
commit 15ead6b561
2 changed files with 93 additions and 0 deletions

View File

@ -7,6 +7,7 @@ import platform
import trio
import tractor
from tractor.testing import tractor_test
import pytest
@ -53,6 +54,7 @@ async def stream_from_single_subactor(stream_func_name):
"""Verify we can spawn a daemon actor and retrieve streamed data.
"""
async with tractor.find_actor('streamerd') as portals:
if not portals:
# only one per host address, spawns an actor if None
async with tractor.open_nursery() as nursery:
@ -73,8 +75,10 @@ async def stream_from_single_subactor(stream_func_name):
# it'd sure be nice to have an asyncitertools here...
iseq = iter(seq)
ival = next(iseq)
async for val in stream:
assert val == ival
try:
ival = next(iseq)
except StopIteration:
@ -83,6 +87,7 @@ async def stream_from_single_subactor(stream_func_name):
await stream.aclose()
await trio.sleep(0.3)
try:
await stream.__anext__()
except StopAsyncIteration:
@ -109,8 +114,11 @@ def test_stream_from_single_subactor(arb_addr, start_method, stream_func):
# this is the first 2 actors, streamer_1 and streamer_2
async def stream_data(seed):
for i in range(seed):
yield i
# trigger scheduler to simulate practical usage
await trio.sleep(0)
@ -246,3 +254,68 @@ def test_not_fast_enough_quad(
else:
# should be cancelled mid-streaming
assert results is None
@tractor_test
async def test_respawn_consumer_task(
arb_addr,
spawn_backend,
loglevel,
):
"""Verify that ``._portal.StreamReceiveChannel.shield_channel()``
sucessfully protects the underlying IPC channel from being closed
when cancelling and respawning a consumer task.
This also serves to verify that all values from the stream can be
received despite the respawns.
"""
stream = None
async with tractor.open_nursery() as n:
stream = await(await n.run_in_actor(
'streamer',
stream_data,
seed=11,
)).result()
expect = set(range(11))
received = []
# this is the re-spawn task routine
async def consume(task_status=trio.TASK_STATUS_IGNORED):
print('starting consume task..')
nonlocal stream
with trio.CancelScope() as cs:
task_status.started(cs)
# shield stream's underlying channel from cancellation
with stream.shield_channel():
async for v in stream:
print(f'from stream: {v}')
expect.remove(v)
received.append(v)
print('exited consume')
async with trio.open_nursery() as ln:
cs = await ln.start(consume)
while True:
await trio.sleep(0.1)
if received[-1] % 2 == 0:
print('cancelling consume task..')
cs.cancel()
# respawn
cs = await ln.start(consume)
if not expect:
print("all values streamed, BREAKING")
break

View File

@ -7,6 +7,7 @@ import typing
from typing import Tuple, Any, Dict, Optional, Set
from functools import partial
from dataclasses import dataclass
from contextlib import contextmanager
import trio
from async_generator import asynccontextmanager
@ -59,6 +60,7 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
self._cid = cid
self._rx_chan = rx_chan
self._portal = portal
self._shielded = False
# delegate directly to underlying mem channel
def receive_nowait(self):
@ -83,6 +85,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
"Received internal error at portal?")
raise unpack_error(msg, self._portal.channel)
@contextmanager
def shield_channel(
self
) -> typing.AsyncGenerator['StreamReceiveChannel', None]:
"""Shield this stream's underlying channel such that a local consumer task
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
"""
self._shielded = True
yield self
self._shielded = False
async def aclose(self):
"""Cancel associated remote actor task and local memory channel
on close.
@ -90,12 +104,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
if self._rx_chan._closed:
log.warning(f"{self} is already closed")
return
if self._shielded:
log.warning(f"{self} is shielded, portal channel being kept alive")
return
cid = self._cid
with trio.move_on_after(0.5) as cs:
cs.shield = True
log.warning(
f"Cancelling stream {cid} to "
f"{self._portal.channel.uid}")
# NOTE: we're telling the far end actor to cancel a task
# corresponding to *this actor*. The far end local channel
# instance is passed to `Actor._cancel_task()` implicitly.