Raise any asyncio errors if in trio task on cancel

pre_bad_close
Tyler Goodlet 2020-12-10 13:48:40 -05:00
parent 5aa5c4a253
commit bdb4b3a050
1 changed files with 14 additions and 2 deletions

View File

@ -8,7 +8,6 @@ from typing import (
Callable, Callable,
AsyncIterator, AsyncIterator,
Awaitable, Awaitable,
Union,
) )
import trio import trio
@ -91,10 +90,12 @@ async def run_task(
def cancel_trio(task): def cancel_trio(task):
"""Cancel the calling ``trio`` task on error. """Cancel the calling ``trio`` task on error.
""" """
nonlocal err nonlocal aio_err
aio_err = task.exception() aio_err = task.exception()
if aio_err: if aio_err:
log.exception(f"asyncio task errorred:\n{aio_err}") log.exception(f"asyncio task errorred:\n{aio_err}")
cancel_scope.cancel() cancel_scope.cancel()
task.add_done_callback(cancel_trio) task.add_done_callback(cancel_trio)
@ -109,6 +110,12 @@ async def run_task(
async with from_aio: async with from_aio:
async for item in from_aio: async for item in from_aio:
yield item yield item
if cancel_scope.cancelled_caught:
# always raise from any captured asyncio error
if aio_err:
raise aio_err
except BaseException as err: except BaseException as err:
if aio_err is not None: if aio_err is not None:
# always raise from any captured asyncio error # always raise from any captured asyncio error
@ -124,6 +131,11 @@ async def run_task(
# return single value # return single value
return await from_aio.receive() return await from_aio.receive()
if cancel_scope.cancelled_caught:
# always raise from any captured asyncio error
if aio_err:
raise aio_err
# Do we need this? # Do we need this?
except BaseException as err: except BaseException as err:
if aio_err is not None: if aio_err is not None: