Use `MsgStream.subscribe()` in `Context.result()`

The case exists where there is multiple tasks consuming from an open
2-way stream created via `Context.open_stream()` where a sibling task is
pulling from the stream while some other task also calls `.result()`.
Previously the `.result()` call would consume (drain) stream messages
directly from the underlying mem chan which would mean any sibling task
would not receive those same messages. Instead, make `.result()` check
if a stream is open and instead consume (and discard) stream msgs using
a `BroadcastReceiver` (via `MsgStream.subscribe()`) such that all
interested tasks get copies of the same packets.
egs_with_ctx_res_consumption
Tyler Goodlet 2022-03-14 07:08:22 -04:00
parent 4a5f041211
commit 3483151aa8
1 changed files with 49 additions and 16 deletions

View File

@ -27,7 +27,8 @@ from typing import (
Optional,
Callable,
AsyncGenerator,
AsyncIterator
AsyncIterator,
TYPE_CHECKING,
)
import warnings
@ -41,6 +42,10 @@ from .log import get_logger
from .trionics import broadcast_receiver, BroadcastReceiver
if TYPE_CHECKING:
from ._portal import Portal
log = get_logger(__name__)
@ -365,7 +370,8 @@ class Context:
_remote_func_type: Optional[str] = None
# only set on the caller side
_portal: Optional['Portal'] = None # type: ignore # noqa
_portal: Optional[Portal] = None # type: ignore # noqa
_stream: Optional[MsgStream] = None
_result: Optional[Any] = False
_error: Optional[BaseException] = None
@ -473,6 +479,7 @@ class Context:
log.cancel(f'Cancelling {side} side of context to {self.chan.uid}')
self._cancel_called = True
ipc_broken: bool = False
if side == 'caller':
if not self._portal:
@ -490,7 +497,14 @@ class Context:
# NOTE: we're telling the far end actor to cancel a task
# corresponding to *this actor*. The far end local channel
# instance is passed to `Actor._cancel_task()` implicitly.
await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
try:
await self._portal.run_from_ns(
'self',
'_cancel_task',
cid=cid,
)
except trio.BrokenResourceError:
ipc_broken = True
if cs.cancelled_caught:
# XXX: there's no way to know if the remote task was indeed
@ -506,7 +520,10 @@ class Context:
"Timed out on cancelling remote task "
f"{cid} for {self._portal.channel.uid}")
# callee side remote task
elif ipc_broken:
log.cancel(
"Transport layer was broken before cancel request "
f"{cid} for {self._portal.channel.uid}")
else:
self._cancel_msg = msg
@ -593,10 +610,11 @@ class Context:
async with MsgStream(
ctx=self,
rx_chan=ctx._recv_chan,
) as rchan:
) as stream:
self._stream = stream
if self._portal:
self._portal._streams.add(rchan)
self._portal._streams.add(stream)
try:
self._stream_opened = True
@ -604,7 +622,7 @@ class Context:
# ensure we aren't cancelled before delivering
# the stream
# await trio.lowlevel.checkpoint()
yield rchan
yield stream
# XXX: Make the stream "one-shot use". On exit, signal
# ``trio.EndOfChannel``/``StopAsyncIteration`` to the
@ -635,25 +653,22 @@ class Context:
if not self._recv_chan._closed: # type: ignore
# wait for a final context result consuming
# and discarding any bi dir stream msgs still
# in transit from the far end.
while True:
def consume(
msg: dict,
msg = await self._recv_chan.receive()
) -> Optional[dict]:
try:
self._result = msg['return']
break
return msg['return']
except KeyError as msgerr:
if 'yield' in msg:
# far end task is still streaming to us so discard
log.warning(f'Discarding stream delivered {msg}')
continue
return
elif 'stop' in msg:
log.debug('Remote stream terminated')
continue
return
# internal error should never get here
assert msg.get('cid'), (
@ -663,6 +678,24 @@ class Context:
msg, self._portal.channel
) from msgerr
# wait for a final context result consuming
# and discarding any bi dir stream msgs still
# in transit from the far end.
if self._stream:
async with self._stream.subscribe() as bstream:
async for msg in bstream:
result = consume(msg)
if result:
self._result = result
if not self._result:
while True:
msg = await self._recv_chan.receive()
result = consume(msg)
if result:
self._result = result
break
return self._result
async def started(