forked from goodboy/tractor
Merge pull request #171 from goodboy/stream_channel_shield
Add a way to shield a stream's underlying channelfix_debug_tests_in_ci_again
commit
1701493087
|
@ -6,6 +6,7 @@ All these tests can be understood (somewhat) by running the equivalent
|
||||||
|
|
||||||
TODO: None of these tests have been run successfully on windows yet.
|
TODO: None of these tests have been run successfully on windows yet.
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -41,10 +42,16 @@ def mk_cmd(ex_name: str) -> str:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def spawn(
|
def spawn(
|
||||||
|
start_method,
|
||||||
testdir,
|
testdir,
|
||||||
arb_addr,
|
arb_addr,
|
||||||
) -> 'pexpect.spawn':
|
) -> 'pexpect.spawn':
|
||||||
|
|
||||||
|
if start_method != 'trio':
|
||||||
|
pytest.skip(
|
||||||
|
"Debugger tests are only supported on the trio backend"
|
||||||
|
)
|
||||||
|
|
||||||
def _spawn(cmd):
|
def _spawn(cmd):
|
||||||
return testdir.spawn(
|
return testdir.spawn(
|
||||||
cmd=mk_cmd(cmd),
|
cmd=mk_cmd(cmd),
|
||||||
|
@ -370,6 +377,8 @@ def test_root_nursery_cancels_before_child_releases_tty_lock(spawn, start_method
|
||||||
child has unblocked (which can happen when it has the tty lock and
|
child has unblocked (which can happen when it has the tty lock and
|
||||||
is engaged in pdb) it is indeed cancelled after exiting the debugger.
|
is engaged in pdb) it is indeed cancelled after exiting the debugger.
|
||||||
"""
|
"""
|
||||||
|
timed_out_early = False
|
||||||
|
|
||||||
child = spawn('root_cancelled_but_child_is_in_tty_lock')
|
child = spawn('root_cancelled_but_child_is_in_tty_lock')
|
||||||
|
|
||||||
child.expect(r"\(Pdb\+\+\)")
|
child.expect(r"\(Pdb\+\+\)")
|
||||||
|
@ -377,9 +386,13 @@ def test_root_nursery_cancels_before_child_releases_tty_lock(spawn, start_method
|
||||||
before = str(child.before.decode())
|
before = str(child.before.decode())
|
||||||
assert "NameError: name 'doggypants' is not defined" in before
|
assert "NameError: name 'doggypants' is not defined" in before
|
||||||
assert "tractor._exceptions.RemoteActorError: ('name_error'" not in before
|
assert "tractor._exceptions.RemoteActorError: ('name_error'" not in before
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
child.sendline('c')
|
child.sendline('c')
|
||||||
|
|
||||||
for _ in range(4):
|
|
||||||
|
for i in range(4):
|
||||||
|
time.sleep(0.5)
|
||||||
try:
|
try:
|
||||||
child.expect(r"\(Pdb\+\+\)")
|
child.expect(r"\(Pdb\+\+\)")
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
|
@ -390,12 +403,25 @@ def test_root_nursery_cancels_before_child_releases_tty_lock(spawn, start_method
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
except pexpect.exceptions.EOF:
|
||||||
|
print(f"Failed early on {i}?")
|
||||||
|
before = str(child.before.decode())
|
||||||
|
|
||||||
|
timed_out_early = True
|
||||||
|
|
||||||
|
# race conditions on how fast the continue is sent?
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
before = str(child.before.decode())
|
before = str(child.before.decode())
|
||||||
assert "NameError: name 'doggypants' is not defined" in before
|
assert "NameError: name 'doggypants' is not defined" in before
|
||||||
|
|
||||||
child.sendline('c')
|
child.sendline('c')
|
||||||
|
|
||||||
child.expect(pexpect.EOF)
|
child.expect(pexpect.EOF)
|
||||||
|
|
||||||
|
if not timed_out_early:
|
||||||
|
|
||||||
before = str(child.before.decode())
|
before = str(child.before.decode())
|
||||||
assert "tractor._exceptions.RemoteActorError: ('spawner0'" in before
|
assert "tractor._exceptions.RemoteActorError: ('spawner0'" in before
|
||||||
assert "tractor._exceptions.RemoteActorError: ('name_error'" in before
|
assert "tractor._exceptions.RemoteActorError: ('name_error'" in before
|
||||||
|
|
|
@ -7,6 +7,7 @@ import platform
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import tractor
|
import tractor
|
||||||
|
from tractor.testing import tractor_test
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,6 +54,7 @@ async def stream_from_single_subactor(stream_func_name):
|
||||||
"""Verify we can spawn a daemon actor and retrieve streamed data.
|
"""Verify we can spawn a daemon actor and retrieve streamed data.
|
||||||
"""
|
"""
|
||||||
async with tractor.find_actor('streamerd') as portals:
|
async with tractor.find_actor('streamerd') as portals:
|
||||||
|
|
||||||
if not portals:
|
if not portals:
|
||||||
# only one per host address, spawns an actor if None
|
# only one per host address, spawns an actor if None
|
||||||
async with tractor.open_nursery() as nursery:
|
async with tractor.open_nursery() as nursery:
|
||||||
|
@ -73,8 +75,10 @@ async def stream_from_single_subactor(stream_func_name):
|
||||||
# it'd sure be nice to have an asyncitertools here...
|
# it'd sure be nice to have an asyncitertools here...
|
||||||
iseq = iter(seq)
|
iseq = iter(seq)
|
||||||
ival = next(iseq)
|
ival = next(iseq)
|
||||||
|
|
||||||
async for val in stream:
|
async for val in stream:
|
||||||
assert val == ival
|
assert val == ival
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ival = next(iseq)
|
ival = next(iseq)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
|
@ -83,6 +87,7 @@ async def stream_from_single_subactor(stream_func_name):
|
||||||
await stream.aclose()
|
await stream.aclose()
|
||||||
|
|
||||||
await trio.sleep(0.3)
|
await trio.sleep(0.3)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await stream.__anext__()
|
await stream.__anext__()
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
|
@ -109,8 +114,11 @@ def test_stream_from_single_subactor(arb_addr, start_method, stream_func):
|
||||||
|
|
||||||
# this is the first 2 actors, streamer_1 and streamer_2
|
# this is the first 2 actors, streamer_1 and streamer_2
|
||||||
async def stream_data(seed):
|
async def stream_data(seed):
|
||||||
|
|
||||||
for i in range(seed):
|
for i in range(seed):
|
||||||
|
|
||||||
yield i
|
yield i
|
||||||
|
|
||||||
# trigger scheduler to simulate practical usage
|
# trigger scheduler to simulate practical usage
|
||||||
await trio.sleep(0)
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
@ -246,3 +254,68 @@ def test_not_fast_enough_quad(
|
||||||
else:
|
else:
|
||||||
# should be cancelled mid-streaming
|
# should be cancelled mid-streaming
|
||||||
assert results is None
|
assert results is None
|
||||||
|
|
||||||
|
|
||||||
|
@tractor_test
|
||||||
|
async def test_respawn_consumer_task(
|
||||||
|
arb_addr,
|
||||||
|
spawn_backend,
|
||||||
|
loglevel,
|
||||||
|
):
|
||||||
|
"""Verify that ``._portal.ReceiveStream.shield()``
|
||||||
|
sucessfully protects the underlying IPC channel from being closed
|
||||||
|
when cancelling and respawning a consumer task.
|
||||||
|
|
||||||
|
This also serves to verify that all values from the stream can be
|
||||||
|
received despite the respawns.
|
||||||
|
|
||||||
|
"""
|
||||||
|
stream = None
|
||||||
|
|
||||||
|
async with tractor.open_nursery() as n:
|
||||||
|
|
||||||
|
stream = await(await n.run_in_actor(
|
||||||
|
'streamer',
|
||||||
|
stream_data,
|
||||||
|
seed=11,
|
||||||
|
)).result()
|
||||||
|
|
||||||
|
expect = set(range(11))
|
||||||
|
received = []
|
||||||
|
|
||||||
|
# this is the re-spawn task routine
|
||||||
|
async def consume(task_status=trio.TASK_STATUS_IGNORED):
|
||||||
|
print('starting consume task..')
|
||||||
|
nonlocal stream
|
||||||
|
|
||||||
|
with trio.CancelScope() as cs:
|
||||||
|
task_status.started(cs)
|
||||||
|
|
||||||
|
# shield stream's underlying channel from cancellation
|
||||||
|
with stream.shield():
|
||||||
|
|
||||||
|
async for v in stream:
|
||||||
|
print(f'from stream: {v}')
|
||||||
|
expect.remove(v)
|
||||||
|
received.append(v)
|
||||||
|
|
||||||
|
print('exited consume')
|
||||||
|
|
||||||
|
async with trio.open_nursery() as ln:
|
||||||
|
cs = await ln.start(consume)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
|
if received[-1] % 2 == 0:
|
||||||
|
|
||||||
|
print('cancelling consume task..')
|
||||||
|
cs.cancel()
|
||||||
|
|
||||||
|
# respawn
|
||||||
|
cs = await ln.start(consume)
|
||||||
|
|
||||||
|
if not expect:
|
||||||
|
print("all values streamed, BREAKING")
|
||||||
|
break
|
||||||
|
|
|
@ -214,9 +214,12 @@ class Channel:
|
||||||
# # time is pointless
|
# # time is pointless
|
||||||
# await self.msgstream.send(sent)
|
# await self.msgstream.send(sent)
|
||||||
except trio.BrokenResourceError:
|
except trio.BrokenResourceError:
|
||||||
|
|
||||||
if not self._autorecon:
|
if not self._autorecon:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
await self.aclose()
|
await self.aclose()
|
||||||
|
|
||||||
if self._autorecon: # attempt reconnect
|
if self._autorecon: # attempt reconnect
|
||||||
await self._reconnect()
|
await self._reconnect()
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -4,9 +4,10 @@ Portal api
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import typing
|
import typing
|
||||||
from typing import Tuple, Any, Dict, Optional, Set
|
from typing import Tuple, Any, Dict, Optional, Set, Iterator
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
from async_generator import asynccontextmanager
|
from async_generator import asynccontextmanager
|
||||||
|
@ -37,7 +38,7 @@ async def maybe_open_nursery(
|
||||||
yield nursery
|
yield nursery
|
||||||
|
|
||||||
|
|
||||||
class StreamReceiveChannel(trio.abc.ReceiveChannel):
|
class ReceiveStream(trio.abc.ReceiveChannel):
|
||||||
"""A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
|
"""A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
|
||||||
special behaviour for signalling stream termination across an
|
special behaviour for signalling stream termination across an
|
||||||
inter-actor ``Channel``. This is the type returned to a local task
|
inter-actor ``Channel``. This is the type returned to a local task
|
||||||
|
@ -59,6 +60,7 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
|
||||||
self._cid = cid
|
self._cid = cid
|
||||||
self._rx_chan = rx_chan
|
self._rx_chan = rx_chan
|
||||||
self._portal = portal
|
self._portal = portal
|
||||||
|
self._shielded = False
|
||||||
|
|
||||||
# delegate directly to underlying mem channel
|
# delegate directly to underlying mem channel
|
||||||
def receive_nowait(self):
|
def receive_nowait(self):
|
||||||
|
@ -83,6 +85,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
|
||||||
"Received internal error at portal?")
|
"Received internal error at portal?")
|
||||||
raise unpack_error(msg, self._portal.channel)
|
raise unpack_error(msg, self._portal.channel)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def shield(
|
||||||
|
self
|
||||||
|
) -> Iterator['ReceiveStream']: # noqa
|
||||||
|
"""Shield this stream's underlying channel such that a local consumer task
|
||||||
|
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._shielded = True
|
||||||
|
yield self
|
||||||
|
self._shielded = False
|
||||||
|
|
||||||
async def aclose(self):
|
async def aclose(self):
|
||||||
"""Cancel associated remote actor task and local memory channel
|
"""Cancel associated remote actor task and local memory channel
|
||||||
on close.
|
on close.
|
||||||
|
@ -90,12 +104,18 @@ class StreamReceiveChannel(trio.abc.ReceiveChannel):
|
||||||
if self._rx_chan._closed:
|
if self._rx_chan._closed:
|
||||||
log.warning(f"{self} is already closed")
|
log.warning(f"{self} is already closed")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self._shielded:
|
||||||
|
log.warning(f"{self} is shielded, portal channel being kept alive")
|
||||||
|
return
|
||||||
|
|
||||||
cid = self._cid
|
cid = self._cid
|
||||||
with trio.move_on_after(0.5) as cs:
|
with trio.move_on_after(0.5) as cs:
|
||||||
cs.shield = True
|
cs.shield = True
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Cancelling stream {cid} to "
|
f"Cancelling stream {cid} to "
|
||||||
f"{self._portal.channel.uid}")
|
f"{self._portal.channel.uid}")
|
||||||
|
|
||||||
# NOTE: we're telling the far end actor to cancel a task
|
# NOTE: we're telling the far end actor to cancel a task
|
||||||
# corresponding to *this actor*. The far end local channel
|
# corresponding to *this actor*. The far end local channel
|
||||||
# instance is passed to `Actor._cancel_task()` implicitly.
|
# instance is passed to `Actor._cancel_task()` implicitly.
|
||||||
|
@ -136,7 +156,7 @@ class Portal:
|
||||||
self._expect_result: Optional[
|
self._expect_result: Optional[
|
||||||
Tuple[str, Any, str, Dict[str, Any]]
|
Tuple[str, Any, str, Dict[str, Any]]
|
||||||
] = None
|
] = None
|
||||||
self._streams: Set[StreamReceiveChannel] = set()
|
self._streams: Set[ReceiveStream] = set()
|
||||||
self.actor = current_actor()
|
self.actor = current_actor()
|
||||||
|
|
||||||
async def _submit(
|
async def _submit(
|
||||||
|
@ -199,7 +219,7 @@ class Portal:
|
||||||
# to make async-generators the fundamental IPC API over channels!
|
# to make async-generators the fundamental IPC API over channels!
|
||||||
# (think `yield from`, `gen.send()`, and functional reactive stuff)
|
# (think `yield from`, `gen.send()`, and functional reactive stuff)
|
||||||
if resptype == 'yield': # stream response
|
if resptype == 'yield': # stream response
|
||||||
rchan = StreamReceiveChannel(cid, recv_chan, self)
|
rchan = ReceiveStream(cid, recv_chan, self)
|
||||||
self._streams.add(rchan)
|
self._streams.add(rchan)
|
||||||
return rchan
|
return rchan
|
||||||
|
|
||||||
|
@ -302,7 +322,7 @@ class LocalPortal:
|
||||||
A compatibility shim for normal portals but for invoking functions
|
A compatibility shim for normal portals but for invoking functions
|
||||||
using an in process actor instance.
|
using an in process actor instance.
|
||||||
"""
|
"""
|
||||||
actor: 'Actor' # type: ignore
|
actor: 'Actor' # type: ignore # noqa
|
||||||
channel: Channel
|
channel: Channel
|
||||||
|
|
||||||
async def run(self, ns: str, func_name: str, **kwargs) -> Any:
|
async def run(self, ns: str, func_name: str, **kwargs) -> Any:
|
||||||
|
|
Loading…
Reference in New Issue