Augment test cases for callee-returns-result early

Turns out stuff was totally broken in these cases because we're either
closing the underlying mem chan too early or not handling the
"allow_overruns" mode's cancellation correctly..
proper_breakpoint_hooking
Tyler Goodlet 2023-04-13 15:14:49 -04:00
parent e16e7ca82a
commit ac51cf07b9
1 changed files with 84 additions and 38 deletions

View File

@ -234,19 +234,25 @@ def test_simple_context(
trio.run(main) trio.run(main)
@pytest.mark.parametrize(
'callee_returns_early',
[True, False],
ids=lambda item: f'callee_returns_early={item}'
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'cancel_method', 'cancel_method',
['ctx', 'portal'], ['ctx', 'portal'],
ids=lambda item: f'cancel_method={item}' ids=lambda item: f'cancel_method={item}'
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
'result_before_exit', 'chk_ctx_result_before_exit',
[True, False], [True, False],
ids=lambda item: f'result_before_exit={item}' ids=lambda item: f'chk_ctx_result_before_exit={item}'
) )
def test_caller_cancels( def test_caller_cancels(
cancel_method: str, cancel_method: str,
result_before_exit: bool, chk_ctx_result_before_exit: bool,
callee_returns_early: bool,
): ):
''' '''
Verify that when the opening side of a context (aka the caller) Verify that when the opening side of a context (aka the caller)
@ -254,13 +260,18 @@ def test_caller_cancels(
either calling `.result()` or on context exit. either calling `.result()` or on context exit.
''' '''
async def check_canceller( async def check_canceller(
ctx: tractor.Context, ctx: tractor.Context,
) -> None: ) -> None:
# should not raise yet return the remote # should not raise yet return the remote
# context cancelled error. # context cancelled error.
err = await ctx.result() res = await ctx.result()
if callee_returns_early:
assert res == 'yo'
else:
err = res
assert isinstance(err, ContextCancelled) assert isinstance(err, ContextCancelled)
assert ( assert (
tuple(err.canceller) tuple(err.canceller)
@ -274,22 +285,29 @@ def test_caller_cancels(
'simple_context', 'simple_context',
enable_modules=[__name__], enable_modules=[__name__],
) )
with trio.fail_after(0.5): timeout = 0.5 if not callee_returns_early else 2
with trio.fail_after(timeout):
async with portal.open_context( async with portal.open_context(
simple_setup_teardown, simple_setup_teardown,
data=10, data=10,
block_forever=True, block_forever=not callee_returns_early,
) as (ctx, sent): ) as (ctx, sent):
if callee_returns_early:
# ensure we block long enough before sending
# a cancel such that the callee has already
# returned it's result.
await trio.sleep(0.5)
if cancel_method == 'ctx': if cancel_method == 'ctx':
await ctx.cancel() await ctx.cancel()
else: else:
await portal.cancel_actor() await portal.cancel_actor()
if result_before_exit: if chk_ctx_result_before_exit:
await check_canceller(ctx) await check_canceller(ctx)
if not result_before_exit: if not chk_ctx_result_before_exit:
await check_canceller(ctx) await check_canceller(ctx)
if cancel_method != 'portal': if cancel_method != 'portal':
@ -703,46 +721,71 @@ async def echo_back_sequence(
ctx: tractor.Context, ctx: tractor.Context,
seq: list[int], seq: list[int],
wait_for_cancel: bool,
msg_buffer_size: int | None = None, msg_buffer_size: int | None = None,
) -> None: ) -> None:
''' '''
Send endlessly on the calleee stream. Send endlessly on the calleee stream using a small buffer size
setting on the contex to simulate backlogging that would normally
cause overruns.
''' '''
# NOTE: ensure that if the caller is expecting to cancel this task
# that we stay echoing much longer then they are so we don't
# return early instead of receive the cancel msg.
total_batches: int = 1000 if wait_for_cancel else 6
await ctx.started() await ctx.started()
async with ctx.open_stream( async with ctx.open_stream(
msg_buffer_size=msg_buffer_size, msg_buffer_size=msg_buffer_size,
backpressure=True, allow_overruns=True,
) as stream: ) as stream:
seq = list(seq) # bleh, `msgpack`... seq = list(seq) # bleh, `msgpack`...
count = 0 for _ in range(total_batches):
# while count < 10:
while True:
batch = [] batch = []
async for msg in stream: async for msg in stream:
batch.append(msg) batch.append(msg)
if batch == seq: if batch == seq:
break break
print('callee waiting on next')
for msg in batch: for msg in batch:
print(f'callee sending {msg}') print(f'callee sending {msg}')
await stream.send(msg) await stream.send(msg)
count += 1 print(
'EXITING CALLEEE:\n'
print("EXITING CALLEEE") f'{ctx.cancel_called_remote}'
)
return 'yo' return 'yo'
def test_stream_backpressure( @pytest.mark.parametrize(
'cancel_ctx',
[True, False],
ids=lambda item: f'cancel_ctx={item}'
)
def test_allow_overruns_stream(
cancel_ctx: bool,
loglevel: str, loglevel: str,
): ):
''' '''
Demonstrate small overruns of each task back and forth Demonstrate small overruns of each task back and forth
on a stream not raising any errors by default by setting on a stream not raising any errors by default by setting
the ``backpressure=True``. the ``allow_overruns=True``.
The original idea here was to show that if you set the feeder mem
chan to a size smaller then the # of msgs sent you could could not
get a `StreamOverrun` crash plus maybe get all the msgs that were
sent. The problem with the "real backpressure" case is that due to
the current arch it can result in the msg loop being blocked and thus
blocking cancellation - which is like super bad. So instead this test
had to be adjusted to more or less just "not send overrun errors" so
as to handle the case where the sender just moreso cares about not getting
errored out when it send to fast..
''' '''
async def main(): async def main():
@ -756,41 +799,44 @@ def test_stream_backpressure(
async with portal.open_context( async with portal.open_context(
echo_back_sequence, echo_back_sequence,
seq=seq, seq=seq,
wait_for_cancel=cancel_ctx,
) as (ctx, sent): ) as (ctx, sent):
assert sent is None assert sent is None
async with ctx.open_stream( async with ctx.open_stream(
msg_buffer_size=1, msg_buffer_size=1,
backpressure=True, allow_overruns=True,
# allow_overruns=True,
) as stream: ) as stream:
count = 0 count = 0
while count < 3: while count < 3:
for msg in seq: for msg in seq:
print(f'caller sending {msg}') print(f'root tx {msg}')
await stream.send(msg) await stream.send(msg)
await trio.sleep(0.1) await trio.sleep(0.1)
batch = [] batch = []
# with trio.move_on_after(1) as cs:
async for msg in stream: async for msg in stream:
print(f'RX {msg}') print(f'root rx {msg}')
batch.append(msg) batch.append(msg)
if batch == seq: if batch == seq:
break break
count += 1 count += 1
# if cs.cancelled_caught: if cancel_ctx:
# break
# cancel the remote task # cancel the remote task
# print('SENDING ROOT SIDE CANCEL') print('sending root side cancel')
# await ctx.cancel() await ctx.cancel()
# here the context should return
res = await ctx.result() res = await ctx.result()
if cancel_ctx:
assert isinstance(res, ContextCancelled)
assert tuple(res.canceller) == tractor.current_actor().uid
else:
print(f'RX ROOT SIDE RESULT {res}')
assert res == 'yo' assert res == 'yo'
# cancel the daemon # cancel the daemon