Allow shielding in `open_portal()`

dereg_on_channel_aclose
Tyler Goodlet 2020-08-08 14:47:52 -04:00
parent 532429aec9
commit 90c7fa6963
1 changed files with 7 additions and 2 deletions

View File

@ -22,7 +22,8 @@ log = get_logger('tractor')
@asynccontextmanager @asynccontextmanager
async def maybe_open_nursery( async def maybe_open_nursery(
nursery: trio.Nursery = None nursery: trio.Nursery = None,
shield: bool = False,
) -> typing.AsyncGenerator[trio.Nursery, Any]: ) -> typing.AsyncGenerator[trio.Nursery, Any]:
"""Create a new nursery if None provided. """Create a new nursery if None provided.
@ -32,6 +33,7 @@ async def maybe_open_nursery(
yield nursery yield nursery
else: else:
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
nursery.cancel_scope.shield = shield
yield nursery yield nursery
@ -275,6 +277,8 @@ class Portal:
f"{self.channel}") f"{self.channel}")
try: try:
# send cancel cmd - might not get response # send cancel cmd - might not get response
# XXX: sure would be nice to make this work with a proper shield
# with trio.CancelScope(shield=True):
with trio.move_on_after(0.5) as cancel_scope: with trio.move_on_after(0.5) as cancel_scope:
cancel_scope.shield = True cancel_scope.shield = True
await self.run('self', 'cancel') await self.run('self', 'cancel')
@ -316,6 +320,7 @@ async def open_portal(
channel: Channel, channel: Channel,
nursery: Optional[trio.Nursery] = None, nursery: Optional[trio.Nursery] = None,
start_msg_loop: bool = True, start_msg_loop: bool = True,
shield: bool = False,
) -> typing.AsyncGenerator[Portal, None]: ) -> typing.AsyncGenerator[Portal, None]:
"""Open a ``Portal`` through the provided ``channel``. """Open a ``Portal`` through the provided ``channel``.
@ -325,7 +330,7 @@ async def open_portal(
assert actor assert actor
was_connected = False was_connected = False
async with maybe_open_nursery(nursery) as nursery: async with maybe_open_nursery(nursery, shield=shield) as nursery:
if not channel.connected(): if not channel.connected():
await channel.connect() await channel.connect()
was_connected = True was_connected = True