forked from goodboy/tractor
				
			Use `MsgStream.subscribe()` in `Context.result()`
The case exists where there is multiple tasks consuming from an open 2-way stream created via `Context.open_stream()` where a sibling task is pulling from the stream while some other task also calls `.result()`. Previously the `.result()` call would consume (drain) stream messages directly from the underlying mem chan which would mean any sibling task would not receive those same messages. Instead, make `.result()` check if a stream is open and instead consume (and discard) stream msgs using a `BroadcastReceiver` (via `MsgStream.subscribe()`) such that all interested tasks get copies of the same packets.ctx_result_consumption
							parent
							
								
									f7a1f3832f
								
							
						
					
					
						commit
						c9eb466d76
					
				|  | @ -27,7 +27,8 @@ from typing import ( | ||||||
|     Optional, |     Optional, | ||||||
|     Callable, |     Callable, | ||||||
|     AsyncGenerator, |     AsyncGenerator, | ||||||
|     AsyncIterator |     AsyncIterator, | ||||||
|  |     TYPE_CHECKING, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| import warnings | import warnings | ||||||
|  | @ -41,6 +42,10 @@ from .log import get_logger | ||||||
| from .trionics import broadcast_receiver, BroadcastReceiver | from .trionics import broadcast_receiver, BroadcastReceiver | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from ._portal import Portal | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| log = get_logger(__name__) | log = get_logger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -378,7 +383,8 @@ class Context: | ||||||
|     _remote_func_type: Optional[str] = None |     _remote_func_type: Optional[str] = None | ||||||
| 
 | 
 | ||||||
|     # only set on the caller side |     # only set on the caller side | ||||||
|     _portal: Optional['Portal'] = None    # type: ignore # noqa |     _portal: Optional[Portal] = None    # type: ignore # noqa | ||||||
|  |     _stream: Optional[MsgStream] = None | ||||||
|     _result: Optional[Any] = False |     _result: Optional[Any] = False | ||||||
|     _error: Optional[BaseException] = None |     _error: Optional[BaseException] = None | ||||||
| 
 | 
 | ||||||
|  | @ -486,6 +492,7 @@ class Context: | ||||||
|         log.cancel(f'Cancelling {side} side of context to {self.chan.uid}') |         log.cancel(f'Cancelling {side} side of context to {self.chan.uid}') | ||||||
| 
 | 
 | ||||||
|         self._cancel_called = True |         self._cancel_called = True | ||||||
|  |         ipc_broken: bool = False | ||||||
| 
 | 
 | ||||||
|         if side == 'caller': |         if side == 'caller': | ||||||
|             if not self._portal: |             if not self._portal: | ||||||
|  | @ -503,7 +510,14 @@ class Context: | ||||||
|                 # NOTE: we're telling the far end actor to cancel a task |                 # NOTE: we're telling the far end actor to cancel a task | ||||||
|                 # corresponding to *this actor*. The far end local channel |                 # corresponding to *this actor*. The far end local channel | ||||||
|                 # instance is passed to `Actor._cancel_task()` implicitly. |                 # instance is passed to `Actor._cancel_task()` implicitly. | ||||||
|                 await self._portal.run_from_ns('self', '_cancel_task', cid=cid) |                 try: | ||||||
|  |                     await self._portal.run_from_ns( | ||||||
|  |                         'self', | ||||||
|  |                         '_cancel_task', | ||||||
|  |                         cid=cid, | ||||||
|  |                     ) | ||||||
|  |                 except trio.BrokenResourceError: | ||||||
|  |                     ipc_broken = True | ||||||
| 
 | 
 | ||||||
|             if cs.cancelled_caught: |             if cs.cancelled_caught: | ||||||
|                 # XXX: there's no way to know if the remote task was indeed |                 # XXX: there's no way to know if the remote task was indeed | ||||||
|  | @ -519,7 +533,10 @@ class Context: | ||||||
|                         "Timed out on cancelling remote task " |                         "Timed out on cancelling remote task " | ||||||
|                         f"{cid} for {self._portal.channel.uid}") |                         f"{cid} for {self._portal.channel.uid}") | ||||||
| 
 | 
 | ||||||
|         # callee side remote task |             elif ipc_broken: | ||||||
|  |                 log.cancel( | ||||||
|  |                     "Transport layer was broken before cancel request " | ||||||
|  |                     f"{cid} for {self._portal.channel.uid}") | ||||||
|         else: |         else: | ||||||
|             self._cancel_msg = msg |             self._cancel_msg = msg | ||||||
| 
 | 
 | ||||||
|  | @ -607,6 +624,7 @@ class Context: | ||||||
|             ctx=self, |             ctx=self, | ||||||
|             rx_chan=ctx._recv_chan, |             rx_chan=ctx._recv_chan, | ||||||
|         ) as stream: |         ) as stream: | ||||||
|  |             self._stream = stream | ||||||
| 
 | 
 | ||||||
|             if self._portal: |             if self._portal: | ||||||
|                 self._portal._streams.add(stream) |                 self._portal._streams.add(stream) | ||||||
|  | @ -648,25 +666,22 @@ class Context: | ||||||
| 
 | 
 | ||||||
|             if not self._recv_chan._closed:  # type: ignore |             if not self._recv_chan._closed:  # type: ignore | ||||||
| 
 | 
 | ||||||
|                 # wait for a final context result consuming |                 def consume( | ||||||
|                 # and discarding any bi dir stream msgs still |                     msg: dict, | ||||||
|                 # in transit from the far end. |  | ||||||
|                 while True: |  | ||||||
| 
 | 
 | ||||||
|                     msg = await self._recv_chan.receive() |                 ) -> Optional[dict]: | ||||||
|                     try: |                     try: | ||||||
|                         self._result = msg['return'] |                         return msg['return'] | ||||||
|                         break |  | ||||||
|                     except KeyError as msgerr: |                     except KeyError as msgerr: | ||||||
| 
 | 
 | ||||||
|                         if 'yield' in msg: |                         if 'yield' in msg: | ||||||
|                             # far end task is still streaming to us so discard |                             # far end task is still streaming to us so discard | ||||||
|                             log.warning(f'Discarding stream delivered {msg}') |                             log.warning(f'Discarding stream delivered {msg}') | ||||||
|                             continue |                             return | ||||||
| 
 | 
 | ||||||
|                         elif 'stop' in msg: |                         elif 'stop' in msg: | ||||||
|                             log.debug('Remote stream terminated') |                             log.debug('Remote stream terminated') | ||||||
|                             continue |                             return | ||||||
| 
 | 
 | ||||||
|                         # internal error should never get here |                         # internal error should never get here | ||||||
|                         assert msg.get('cid'), ( |                         assert msg.get('cid'), ( | ||||||
|  | @ -676,6 +691,24 @@ class Context: | ||||||
|                             msg, self._portal.channel |                             msg, self._portal.channel | ||||||
|                         ) from msgerr |                         ) from msgerr | ||||||
| 
 | 
 | ||||||
|  |                 # wait for a final context result consuming | ||||||
|  |                 # and discarding any bi dir stream msgs still | ||||||
|  |                 # in transit from the far end. | ||||||
|  |                 if self._stream: | ||||||
|  |                     async with self._stream.subscribe() as bstream: | ||||||
|  |                         async for msg in bstream: | ||||||
|  |                             result = consume(msg) | ||||||
|  |                             if result: | ||||||
|  |                                 self._result = result | ||||||
|  | 
 | ||||||
|  |                 if not self._result: | ||||||
|  |                     while True: | ||||||
|  |                         msg = await self._recv_chan.receive() | ||||||
|  |                         result = consume(msg) | ||||||
|  |                         if result: | ||||||
|  |                             self._result = result | ||||||
|  |                             break | ||||||
|  | 
 | ||||||
|         return self._result |         return self._result | ||||||
| 
 | 
 | ||||||
|     async def started( |     async def started( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue