From 5e31806770d0f92a0aaee287ceb3d4b912e47b67 Mon Sep 17 00:00:00 2001
From: Tyler Goodlet <jgbt@protonmail.com>
Date: Fri, 19 May 2023 14:03:07 -0400
Subject: [PATCH] Ensure user-allocated cancel scope just works!

Turns out the nursery doesn't have to care about allocating a per task
`CancelScope` since the user can just do that in the
`@task_scope_manager` if desired B) So just mask all the nursery cs
allocating with the intention of removal.

Also add a test for per-task-cancellation by starting the crash task as
a `trio.sleep_forever()` but then cancel it via the user allocated cs
and ensure the crash propagates as expected :boom:
---
 tractor/trionics/_supervisor.py | 134 ++++++++++++++++++++------------
 1 file changed, 83 insertions(+), 51 deletions(-)

diff --git a/tractor/trionics/_supervisor.py b/tractor/trionics/_supervisor.py
index dcd20502..d23e1df8 100644
--- a/tractor/trionics/_supervisor.py
+++ b/tractor/trionics/_supervisor.py
@@ -126,18 +126,23 @@ class ScopePerTaskNursery(Struct):
         # task = new_tasks.pop()
 
         n: Nursery = self._n
-        cs = CancelScope()
+
+        sm = self.scope_manager
+        # we do default behavior of a scope-per-nursery
+        # if the user did not provide a task manager.
+        if sm is None:
+            return n.start_soon(async_fn, *args, name=None)
+
+        # per_task_cs = CancelScope()
         new_task: Task | None = None
         to_return: tuple[Any] | None = None
 
-        sm = self.scope_manager
-        if sm is None:
-            mngr = nullcontext([cs])
-        else:
-            # NOTE: what do we enforce as a signature for the
-            # `@task_scope_manager` here?
-            mngr = sm(nursery=n)
-
+        # NOTE: what do we enforce as a signature for the
+        # `@task_scope_manager` here?
+        mngr = sm(
+            nursery=n,
+            # scope=per_task_cs,
+        )
         async def _start_wrapped_in_scope(
             task_status: TaskStatus[
                 tuple[CancelScope, Task]
@@ -148,48 +153,49 @@ class ScopePerTaskNursery(Struct):
             # TODO: this was working before?!
             # nonlocal to_return
 
-            with cs:
+            task = trio.lowlevel.current_task()
+            # self._scopes[per_task_cs] = task
 
-                task = trio.lowlevel.current_task()
-                self._scopes[cs] = task
+            # NOTE: we actually don't need this since the user can
+            # just to it themselves inside mngr!
+            # with per_task_cs:
 
-                # execute up to the first yield
-                try:
-                    to_return: tuple[Any] = next(mngr)
-                except StopIteration:
-                    raise RuntimeError("task manager didn't yield") from None
+            # execute up to the first yield
+            try:
+                to_return: tuple[Any] = next(mngr)
+            except StopIteration:
+                raise RuntimeError("task manager didn't yield") from None
 
-                # TODO: how do we support `.start()` style?
-                # - relay through whatever the
-                #   started task passes back via `.started()` ?
-                #   seems like that won't work with also returning
-                #   a "task handle"?
-                # - we were previously binding-out this `to_return` to
-                #   the parent's lexical scope, why isn't that working
-                #   now?
-                task_status.started(to_return)
+            # TODO: how do we support `.start()` style?
+            # - relay through whatever the
+            #   started task passes back via `.started()` ?
+            #   seems like that won't work with also returning
+            #   a "task handle"?
+            # - we were previously binding-out this `to_return` to
+            #   the parent's lexical scope, why isn't that working
+            #   now?
+            task_status.started(to_return)
 
-                # invoke underlying func now that cs is entered.
-                outcome = await acapture(async_fn, *args)
+            # invoke underlying func now that cs is entered.
+            outcome = await acapture(async_fn, *args)
 
-                # execute from the 1st yield to return and expect
-                # generator-mngr `@task_scope_manager` thinger to
-                # terminate!
-                try:
-                    mngr.send(outcome)
+            # execute from the 1st yield to return and expect
+            # generator-mngr `@task_scope_manager` thinger to
+            # terminate!
+            try:
+                mngr.send(outcome)
 
-                    # NOTE: this will instead send the underlying
-                    # `.value`? Not sure if that's better or not?
-                    # I would presume it's better to have a handle to
-                    # the `Outcome` entirely? This method sends *into*
-                    # the mngr this `Outcome.value`; seems like kinda
-                    # weird semantics for our purposes?
-                    # outcome.send(mngr)
 
-                except StopIteration:
-                    return
-                else:
-                    raise RuntimeError(f"{mngr} didn't stop!")
+                # I would presume it's better to have a handle to
+                # the `Outcome` entirely? This method sends *into*
+                # the mngr this `Outcome.value`; seems like kinda
+                # weird semantics for our purposes?
+                # outcome.send(mngr)
+
+            except StopIteration:
+                return
+            else:
+                raise RuntimeError(f"{mngr} didn't stop!")
 
         to_return = await n.start(_start_wrapped_in_scope)
         assert to_return is not None
@@ -200,7 +206,6 @@ class ScopePerTaskNursery(Struct):
         return to_return
 
 
-
 # TODO: you could wrap your output task handle in this?
 # class TaskHandle(Struct):
 #     task: Task
@@ -214,6 +219,11 @@ class ScopePerTaskNursery(Struct):
 def add_task_handle_and_crash_handling(
     nursery: Nursery,
 
+    # TODO: is this the only way we can have a per-task scope
+    # allocated or can we allow the user to somehow do it if
+    # they want below?
+    # scope: CancelScope,
+
 ) -> Generator[None, list[Any]]:
 
     task_outcome = TaskOutcome()
@@ -222,8 +232,12 @@ def add_task_handle_and_crash_handling(
     task: Task = trio.lowlevel.current_task()
     print(f'Spawning task: {task.name}')
 
+    # yields back when task is terminated, cancelled, returns.
     try:
-        # yields back when task is terminated, cancelled, returns?
+        # XXX: wait, this isn't doing anything right since we'd have to
+        # manually activate this scope using something like:
+        # `task._activate_cancel_status(cs._cancel_status)` ??
+        # oh wait, but `.__enter__()` does all that already?
         with CancelScope() as cs:
 
             # the yielded value(s) here are what are returned to the
@@ -260,6 +274,19 @@ async def sleep_then_return_val(val: str):
     return val
 
 
+async def ensure_cancelled():
+    try:
+        await trio.sleep_forever()
+
+    except trio.Cancelled:
+        task = trio.lowlevel.current_task()
+        print(f'heyyo ONLY {task.name} was cancelled as expected B)')
+        assert 0
+
+    except BaseException:
+        raise RuntimeError("woa woa woa this ain't right!")
+
+
 if __name__ == '__main__':
 
     async def main():
@@ -267,17 +294,22 @@ if __name__ == '__main__':
             scope_manager=add_task_handle_and_crash_handling,
         ) as sn:
             for _ in range(3):
-                outcome, cs = await sn.start_soon(trio.sleep_forever)
+                outcome, _ = await sn.start_soon(trio.sleep_forever)
 
             # extra task we want to engage in debugger post mortem.
-            err_outcome, *_ = await sn.start_soon(sleep_then_err)
+            err_outcome, cs = await sn.start_soon(ensure_cancelled)
 
             val: str = 'yoyoyo'
-            val_outcome, cs = await sn.start_soon(sleep_then_return_val, val)
+            val_outcome, _ = await sn.start_soon(
+                sleep_then_return_val,
+                val,
+            )
             res = await val_outcome.wait_for_result()
             assert res == val
-            print(f'GOT EXPECTED TASK VALUE: {res}')
+            print(f'{res} -> GOT EXPECTED TASK VALUE')
 
-            print('WAITING FOR CRASH..')
+            await trio.sleep(0.6)
+            print('Cancelling and waiting for CRASH..')
+            cs.cancel()
 
     trio.run(main)