forked from goodboy/tractor
1
0
Fork 0

Ensure user-allocated cancel scope just works!

Turns out the nursery doesn't have to care about allocating a per task
`CancelScope` since the user can just do that in the
`@task_scope_manager` if desired B) So just mask all the nursery cs
allocating with the intention of removal.

Also add a test for per-task-cancellation by starting the crash task as
a `trio.sleep_forever()` but then cancel it via the user allocated cs
and ensure the crash propagates as expected 💥
oco_supervisor_prototype
Tyler Goodlet 2023-05-19 14:03:07 -04:00
parent f23b5b89dd
commit 56882b680c
1 changed files with 83 additions and 51 deletions

View File

@ -126,18 +126,23 @@ class ScopePerTaskNursery(Struct):
# task = new_tasks.pop()
n: Nursery = self._n
cs = CancelScope()
sm = self.scope_manager
# we do default behavior of a scope-per-nursery
# if the user did not provide a task manager.
if sm is None:
return n.start_soon(async_fn, *args, name=None)
# per_task_cs = CancelScope()
new_task: Task | None = None
to_return: tuple[Any] | None = None
sm = self.scope_manager
if sm is None:
mngr = nullcontext([cs])
else:
# NOTE: what do we enforce as a signature for the
# `@task_scope_manager` here?
mngr = sm(nursery=n)
mngr = sm(
nursery=n,
# scope=per_task_cs,
)
async def _start_wrapped_in_scope(
task_status: TaskStatus[
tuple[CancelScope, Task]
@ -148,10 +153,12 @@ class ScopePerTaskNursery(Struct):
# TODO: this was working before?!
# nonlocal to_return
with cs:
task = trio.lowlevel.current_task()
self._scopes[cs] = task
# self._scopes[per_task_cs] = task
# NOTE: we actually don't need this since the user can
# just to it themselves inside mngr!
# with per_task_cs:
# execute up to the first yield
try:
@ -178,8 +185,7 @@ class ScopePerTaskNursery(Struct):
try:
mngr.send(outcome)
# NOTE: this will instead send the underlying
# `.value`? Not sure if that's better or not?
# I would presume it's better to have a handle to
# the `Outcome` entirely? This method sends *into*
# the mngr this `Outcome.value`; seems like kinda
@ -200,7 +206,6 @@ class ScopePerTaskNursery(Struct):
return to_return
# TODO: you could wrap your output task handle in this?
# class TaskHandle(Struct):
# task: Task
@ -214,6 +219,11 @@ class ScopePerTaskNursery(Struct):
def add_task_handle_and_crash_handling(
nursery: Nursery,
# TODO: is this the only way we can have a per-task scope
# allocated or can we allow the user to somehow do it if
# they want below?
# scope: CancelScope,
) -> Generator[None, list[Any]]:
task_outcome = TaskOutcome()
@ -222,8 +232,12 @@ def add_task_handle_and_crash_handling(
task: Task = trio.lowlevel.current_task()
print(f'Spawning task: {task.name}')
# yields back when task is terminated, cancelled, returns.
try:
# yields back when task is terminated, cancelled, returns?
# XXX: wait, this isn't doing anything right since we'd have to
# manually activate this scope using something like:
# `task._activate_cancel_status(cs._cancel_status)` ??
# oh wait, but `.__enter__()` does all that already?
with CancelScope() as cs:
# the yielded value(s) here are what are returned to the
@ -260,6 +274,19 @@ async def sleep_then_return_val(val: str):
return val
async def ensure_cancelled():
try:
await trio.sleep_forever()
except trio.Cancelled:
task = trio.lowlevel.current_task()
print(f'heyyo ONLY {task.name} was cancelled as expected B)')
assert 0
except BaseException:
raise RuntimeError("woa woa woa this ain't right!")
if __name__ == '__main__':
async def main():
@ -267,17 +294,22 @@ if __name__ == '__main__':
scope_manager=add_task_handle_and_crash_handling,
) as sn:
for _ in range(3):
outcome, cs = await sn.start_soon(trio.sleep_forever)
outcome, _ = await sn.start_soon(trio.sleep_forever)
# extra task we want to engage in debugger post mortem.
err_outcome, *_ = await sn.start_soon(sleep_then_err)
err_outcome, cs = await sn.start_soon(ensure_cancelled)
val: str = 'yoyoyo'
val_outcome, cs = await sn.start_soon(sleep_then_return_val, val)
val_outcome, _ = await sn.start_soon(
sleep_then_return_val,
val,
)
res = await val_outcome.wait_for_result()
assert res == val
print(f'GOT EXPECTED TASK VALUE: {res}')
print(f'{res} -> GOT EXPECTED TASK VALUE')
print('WAITING FOR CRASH..')
await trio.sleep(0.6)
print('Cancelling and waiting for CRASH..')
cs.cancel()
trio.run(main)