Use `MsgStream.subscribe()` in `Context.result()`
The case exists where there is multiple tasks consuming from an open 2-way stream created via `Context.open_stream()` where a sibling task is pulling from the stream while some other task also calls `.result()`. Previously the `.result()` call would consume (drain) stream messages directly from the underlying mem chan which would mean any sibling task would not receive those same messages. Instead, make `.result()` check if a stream is open and instead consume (and discard) stream msgs using a `BroadcastReceiver` (via `MsgStream.subscribe()`) such that all interested tasks get copies of the same packets.ctx_result_consumption
parent
f7a1f3832f
commit
c9eb466d76
|
@ -27,7 +27,8 @@ from typing import (
|
|||
Optional,
|
||||
Callable,
|
||||
AsyncGenerator,
|
||||
AsyncIterator
|
||||
AsyncIterator,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
import warnings
|
||||
|
@ -41,6 +42,10 @@ from .log import get_logger
|
|||
from .trionics import broadcast_receiver, BroadcastReceiver
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._portal import Portal
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -378,7 +383,8 @@ class Context:
|
|||
_remote_func_type: Optional[str] = None
|
||||
|
||||
# only set on the caller side
|
||||
_portal: Optional['Portal'] = None # type: ignore # noqa
|
||||
_portal: Optional[Portal] = None # type: ignore # noqa
|
||||
_stream: Optional[MsgStream] = None
|
||||
_result: Optional[Any] = False
|
||||
_error: Optional[BaseException] = None
|
||||
|
||||
|
@ -486,6 +492,7 @@ class Context:
|
|||
log.cancel(f'Cancelling {side} side of context to {self.chan.uid}')
|
||||
|
||||
self._cancel_called = True
|
||||
ipc_broken: bool = False
|
||||
|
||||
if side == 'caller':
|
||||
if not self._portal:
|
||||
|
@ -503,7 +510,14 @@ class Context:
|
|||
# NOTE: we're telling the far end actor to cancel a task
|
||||
# corresponding to *this actor*. The far end local channel
|
||||
# instance is passed to `Actor._cancel_task()` implicitly.
|
||||
await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
|
||||
try:
|
||||
await self._portal.run_from_ns(
|
||||
'self',
|
||||
'_cancel_task',
|
||||
cid=cid,
|
||||
)
|
||||
except trio.BrokenResourceError:
|
||||
ipc_broken = True
|
||||
|
||||
if cs.cancelled_caught:
|
||||
# XXX: there's no way to know if the remote task was indeed
|
||||
|
@ -519,7 +533,10 @@ class Context:
|
|||
"Timed out on cancelling remote task "
|
||||
f"{cid} for {self._portal.channel.uid}")
|
||||
|
||||
# callee side remote task
|
||||
elif ipc_broken:
|
||||
log.cancel(
|
||||
"Transport layer was broken before cancel request "
|
||||
f"{cid} for {self._portal.channel.uid}")
|
||||
else:
|
||||
self._cancel_msg = msg
|
||||
|
||||
|
@ -607,6 +624,7 @@ class Context:
|
|||
ctx=self,
|
||||
rx_chan=ctx._recv_chan,
|
||||
) as stream:
|
||||
self._stream = stream
|
||||
|
||||
if self._portal:
|
||||
self._portal._streams.add(stream)
|
||||
|
@ -648,25 +666,22 @@ class Context:
|
|||
|
||||
if not self._recv_chan._closed: # type: ignore
|
||||
|
||||
# wait for a final context result consuming
|
||||
# and discarding any bi dir stream msgs still
|
||||
# in transit from the far end.
|
||||
while True:
|
||||
def consume(
|
||||
msg: dict,
|
||||
|
||||
msg = await self._recv_chan.receive()
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
self._result = msg['return']
|
||||
break
|
||||
return msg['return']
|
||||
except KeyError as msgerr:
|
||||
|
||||
if 'yield' in msg:
|
||||
# far end task is still streaming to us so discard
|
||||
log.warning(f'Discarding stream delivered {msg}')
|
||||
continue
|
||||
return
|
||||
|
||||
elif 'stop' in msg:
|
||||
log.debug('Remote stream terminated')
|
||||
continue
|
||||
return
|
||||
|
||||
# internal error should never get here
|
||||
assert msg.get('cid'), (
|
||||
|
@ -676,6 +691,24 @@ class Context:
|
|||
msg, self._portal.channel
|
||||
) from msgerr
|
||||
|
||||
# wait for a final context result consuming
|
||||
# and discarding any bi dir stream msgs still
|
||||
# in transit from the far end.
|
||||
if self._stream:
|
||||
async with self._stream.subscribe() as bstream:
|
||||
async for msg in bstream:
|
||||
result = consume(msg)
|
||||
if result:
|
||||
self._result = result
|
||||
|
||||
if not self._result:
|
||||
while True:
|
||||
msg = await self._recv_chan.receive()
|
||||
result = consume(msg)
|
||||
if result:
|
||||
self._result = result
|
||||
break
|
||||
|
||||
return self._result
|
||||
|
||||
async def started(
|
||||
|
|
Loading…
Reference in New Issue