Merge pull request #278 from goodboy/end_of_channel_fixes
End of channel fixes for streams and broadcastingnew_mypy
commit
2d6fbd5437
|
@ -0,0 +1,12 @@
|
|||
Repair inter-actor stream closure semantics to work correctly with
|
||||
``tractor.trionics.BroadcastReceiver`` task fan out usage.
|
||||
|
||||
A set of previously unknown bugs discovered in `257
|
||||
<https://github.com/goodboy/tractor/pull/257>`_ let graceful stream
|
||||
closure result in hanging consumer tasks that use the broadcast APIs.
|
||||
This adds better internal closure state tracking to the broadcast
|
||||
receiver and message stream APIs and in particular ensures that when an
|
||||
underlying stream/receive-channel (a broadcast receiver is receiving
|
||||
from) is closed, all consumer tasks waiting on that underlying channel
|
||||
are woken so they can receive the ``trio.EndOfChannel`` signal and
|
||||
promptly terminate.
|
|
@ -1,7 +1,7 @@
|
|||
pytest
|
||||
pytest-trio
|
||||
pdbpp
|
||||
mypy
|
||||
mypy<0.920
|
||||
trio_typing
|
||||
pexpect
|
||||
towncrier
|
||||
towncrier
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""
|
||||
'''
|
||||
Advanced streaming patterns using bidirectional streams and contexts.
|
||||
|
||||
"""
|
||||
'''
|
||||
from collections import Counter
|
||||
import itertools
|
||||
from typing import Set, Dict, List
|
||||
|
||||
|
@ -269,3 +270,98 @@ def test_sigint_both_stream_types():
|
|||
assert 0, "Didn't receive KBI!?"
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def inf_streamer(
|
||||
ctx: tractor.Context,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
Stream increasing ints until terminated with a 'done' msg.
|
||||
|
||||
'''
|
||||
await ctx.started()
|
||||
|
||||
async with (
|
||||
ctx.open_stream() as stream,
|
||||
trio.open_nursery() as n,
|
||||
):
|
||||
async def bail_on_sentinel():
|
||||
async for msg in stream:
|
||||
if msg == 'done':
|
||||
await stream.aclose()
|
||||
else:
|
||||
print(f'streamer received {msg}')
|
||||
|
||||
# start termination detector
|
||||
n.start_soon(bail_on_sentinel)
|
||||
|
||||
for val in itertools.count():
|
||||
try:
|
||||
await stream.send(val)
|
||||
except trio.ClosedResourceError:
|
||||
# close out the stream gracefully
|
||||
break
|
||||
|
||||
print('terminating streamer')
|
||||
|
||||
|
||||
def test_local_task_fanout_from_stream():
|
||||
'''
|
||||
Single stream with multiple local consumer tasks using the
|
||||
``MsgStream.subscribe()` api.
|
||||
|
||||
Ensure all tasks receive all values after stream completes sending.
|
||||
|
||||
'''
|
||||
consumers = 22
|
||||
|
||||
async def main():
|
||||
|
||||
counts = Counter()
|
||||
|
||||
async with tractor.open_nursery() as tn:
|
||||
p = await tn.start_actor(
|
||||
'inf_streamer',
|
||||
enable_modules=[__name__],
|
||||
)
|
||||
async with (
|
||||
p.open_context(inf_streamer) as (ctx, _),
|
||||
ctx.open_stream() as stream,
|
||||
):
|
||||
|
||||
async def pull_and_count(name: str):
|
||||
# name = trio.lowlevel.current_task().name
|
||||
async with stream.subscribe() as recver:
|
||||
assert isinstance(
|
||||
recver,
|
||||
tractor.trionics.BroadcastReceiver
|
||||
)
|
||||
async for val in recver:
|
||||
# print(f'{name}: {val}')
|
||||
counts[name] += 1
|
||||
|
||||
print(f'{name} bcaster ended')
|
||||
|
||||
print(f'{name} completed')
|
||||
|
||||
with trio.fail_after(3):
|
||||
async with trio.open_nursery() as nurse:
|
||||
for i in range(consumers):
|
||||
nurse.start_soon(pull_and_count, i)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
print('\nterminating')
|
||||
await stream.send('done')
|
||||
|
||||
print('closed stream connection')
|
||||
|
||||
assert len(counts) == consumers
|
||||
mx = max(counts.values())
|
||||
# make sure each task received all stream values
|
||||
assert all(val == mx for val in counts.values())
|
||||
|
||||
await p.cancel_actor()
|
||||
|
||||
trio.run(main)
|
||||
|
|
|
@ -79,33 +79,36 @@ async def stream_from_single_subactor(
|
|||
|
||||
seq = range(10)
|
||||
|
||||
async with portal.open_stream_from(
|
||||
stream_func,
|
||||
sequence=list(seq), # has to be msgpack serializable
|
||||
) as stream:
|
||||
with trio.fail_after(5):
|
||||
async with portal.open_stream_from(
|
||||
stream_func,
|
||||
sequence=list(seq), # has to be msgpack serializable
|
||||
) as stream:
|
||||
|
||||
# it'd sure be nice to have an asyncitertools here...
|
||||
iseq = iter(seq)
|
||||
ival = next(iseq)
|
||||
# it'd sure be nice to have an asyncitertools here...
|
||||
iseq = iter(seq)
|
||||
ival = next(iseq)
|
||||
|
||||
async for val in stream:
|
||||
assert val == ival
|
||||
async for val in stream:
|
||||
assert val == ival
|
||||
|
||||
try:
|
||||
ival = next(iseq)
|
||||
except StopIteration:
|
||||
# should cancel far end task which will be
|
||||
# caught and no error is raised
|
||||
await stream.aclose()
|
||||
|
||||
await trio.sleep(0.3)
|
||||
|
||||
# ensure EOC signalled-state translates
|
||||
# XXX: not really sure this is correct,
|
||||
# shouldn't it be a `ClosedResourceError`?
|
||||
try:
|
||||
ival = next(iseq)
|
||||
except StopIteration:
|
||||
# should cancel far end task which will be
|
||||
# caught and no error is raised
|
||||
await stream.aclose()
|
||||
|
||||
await trio.sleep(0.3)
|
||||
|
||||
try:
|
||||
await stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# stop all spawned subactors
|
||||
await portal.cancel_actor()
|
||||
# await nursery.cancel()
|
||||
await stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# stop all spawned subactors
|
||||
await portal.cancel_actor()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -78,6 +78,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
|
||||
# flag to denote end of stream
|
||||
self._eoc: bool = False
|
||||
self._closed: bool = False
|
||||
|
||||
# delegate directly to underlying mem channel
|
||||
def receive_nowait(self):
|
||||
|
@ -98,7 +99,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
msg = await self._rx_chan.receive()
|
||||
return msg['yield']
|
||||
|
||||
except KeyError:
|
||||
except KeyError as err:
|
||||
# internal error should never get here
|
||||
assert msg.get('cid'), ("Received internal error at portal?")
|
||||
|
||||
|
@ -107,9 +108,15 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
# - 'error'
|
||||
# possibly just handle msg['stop'] here!
|
||||
|
||||
if msg.get('stop'):
|
||||
if msg.get('stop') or self._eoc:
|
||||
log.debug(f"{self} was stopped at remote end")
|
||||
|
||||
# XXX: important to set so that a new ``.receive()``
|
||||
# call (likely by another task using a broadcast receiver)
|
||||
# doesn't accidentally pull the ``return`` message
|
||||
# value out of the underlying feed mem chan!
|
||||
self._eoc = True
|
||||
|
||||
# # when the send is closed we assume the stream has
|
||||
# # terminated and signal this local iterator to stop
|
||||
# await self.aclose()
|
||||
|
@ -117,7 +124,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
# XXX: this causes ``ReceiveChannel.__anext__()`` to
|
||||
# raise a ``StopAsyncIteration`` **and** in our catch
|
||||
# block below it will trigger ``.aclose()``.
|
||||
raise trio.EndOfChannel
|
||||
raise trio.EndOfChannel from err
|
||||
|
||||
# TODO: test that shows stream raising an expected error!!!
|
||||
elif msg.get('error'):
|
||||
|
@ -162,10 +169,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
raise # propagate
|
||||
|
||||
async def aclose(self):
|
||||
"""Cancel associated remote actor task and local memory channel
|
||||
on close.
|
||||
'''
|
||||
Cancel associated remote actor task and local memory channel on
|
||||
close.
|
||||
|
||||
"""
|
||||
'''
|
||||
# XXX: keep proper adherance to trio's `.aclose()` semantics:
|
||||
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
|
||||
rx_chan = self._rx_chan
|
||||
|
@ -179,6 +187,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
return
|
||||
|
||||
self._eoc = True
|
||||
self._closed = True
|
||||
|
||||
# NOTE: this is super subtle IPC messaging stuff:
|
||||
# Relay stop iteration to far end **iff** we're
|
||||
|
@ -310,15 +319,16 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
|||
self,
|
||||
data: Any
|
||||
) -> None:
|
||||
'''Send a message over this stream to the far end.
|
||||
'''
|
||||
Send a message over this stream to the far end.
|
||||
|
||||
'''
|
||||
# if self._eoc:
|
||||
# raise trio.ClosedResourceError('This stream is already ded')
|
||||
|
||||
if self._ctx._error:
|
||||
raise self._ctx._error # from None
|
||||
|
||||
if self._closed:
|
||||
raise trio.ClosedResourceError('This stream was already closed')
|
||||
|
||||
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
|
||||
|
||||
|
||||
|
|
|
@ -100,6 +100,15 @@ class BroadcastState:
|
|||
# on a newly produced value from the sender.
|
||||
recv_ready: Optional[tuple[int, trio.Event]] = None
|
||||
|
||||
# if a ``trio.EndOfChannel`` is received on any
|
||||
# consumer all consumers should be placed in this state
|
||||
# such that the group is notified of the end-of-broadcast.
|
||||
# For now, this is solely for testing/debugging purposes.
|
||||
eoc: bool = False
|
||||
|
||||
# If the broadcaster was cancelled, we might as well track it
|
||||
cancelled: bool = False
|
||||
|
||||
|
||||
class BroadcastReceiver(ReceiveChannel):
|
||||
'''A memory receive channel broadcaster which is non-lossy for the
|
||||
|
@ -222,10 +231,23 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
event.set()
|
||||
return value
|
||||
|
||||
except trio.Cancelled:
|
||||
except trio.EndOfChannel:
|
||||
# if any one consumer gets an EOC from the underlying
|
||||
# receiver we need to unblock and send that signal to
|
||||
# all other consumers.
|
||||
self._state.eoc = True
|
||||
if event.statistics().tasks_waiting:
|
||||
event.set()
|
||||
raise
|
||||
|
||||
except (
|
||||
trio.Cancelled,
|
||||
):
|
||||
# handle cancelled specially otherwise sibling
|
||||
# consumers will be awoken with a sequence of -1
|
||||
# state.recv_ready = trio.Cancelled
|
||||
# and will potentially try to rewait the underlying
|
||||
# receiver instead of just cancelling immediately.
|
||||
self._state.cancelled = True
|
||||
if event.statistics().tasks_waiting:
|
||||
event.set()
|
||||
raise
|
||||
|
@ -274,11 +296,12 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
async def subscribe(
|
||||
self,
|
||||
) -> AsyncIterator[BroadcastReceiver]:
|
||||
'''Subscribe for values from this broadcast receiver.
|
||||
'''
|
||||
Subscribe for values from this broadcast receiver.
|
||||
|
||||
Returns a new ``BroadCastReceiver`` which is registered for and
|
||||
pulls data from a clone of the original ``trio.abc.ReceiveChannel``
|
||||
provided at creation.
|
||||
pulls data from a clone of the original
|
||||
``trio.abc.ReceiveChannel`` provided at creation.
|
||||
|
||||
'''
|
||||
if self._closed:
|
||||
|
@ -301,7 +324,10 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
async def aclose(
|
||||
self,
|
||||
) -> None:
|
||||
'''
|
||||
Close this receiver without affecting other consumers.
|
||||
|
||||
'''
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
|
|
Loading…
Reference in New Issue