forked from goodboy/tractor
				
			Mark stream with EOC when stop message is received
							parent
							
								
									79d63585b0
								
							
						
					
					
						commit
						f2ba961e81
					
				| 
						 | 
				
			
			@ -78,6 +78,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		|||
 | 
			
		||||
        # flag to denote end of stream
 | 
			
		||||
        self._eoc: bool = False
 | 
			
		||||
        self._closed: bool = False
 | 
			
		||||
 | 
			
		||||
    # delegate directly to underlying mem channel
 | 
			
		||||
    def receive_nowait(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -98,7 +99,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		|||
            msg = await self._rx_chan.receive()
 | 
			
		||||
            return msg['yield']
 | 
			
		||||
 | 
			
		||||
        except KeyError:
 | 
			
		||||
        except KeyError as err:
 | 
			
		||||
            # internal error should never get here
 | 
			
		||||
            assert msg.get('cid'), ("Received internal error at portal?")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -107,9 +108,15 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		|||
            # - 'error'
 | 
			
		||||
            # possibly just handle msg['stop'] here!
 | 
			
		||||
 | 
			
		||||
            if msg.get('stop'):
 | 
			
		||||
            if msg.get('stop') or self._eoc:
 | 
			
		||||
                log.debug(f"{self} was stopped at remote end")
 | 
			
		||||
 | 
			
		||||
                # XXX: important to set so that a new ``.receive()``
 | 
			
		||||
                # call (likely by another task using a broadcast receiver)
 | 
			
		||||
                # doesn't accidentally pull the ``return`` message
 | 
			
		||||
                # value out of the underlying feed mem chan!
 | 
			
		||||
                self._eoc = True
 | 
			
		||||
 | 
			
		||||
                # # when the send is closed we assume the stream has
 | 
			
		||||
                # # terminated and signal this local iterator to stop
 | 
			
		||||
                # await self.aclose()
 | 
			
		||||
| 
						 | 
				
			
			@ -117,7 +124,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		|||
                # XXX: this causes ``ReceiveChannel.__anext__()`` to
 | 
			
		||||
                # raise a ``StopAsyncIteration`` **and** in our catch
 | 
			
		||||
                # block below it will trigger ``.aclose()``.
 | 
			
		||||
                raise trio.EndOfChannel
 | 
			
		||||
                raise trio.EndOfChannel from err
 | 
			
		||||
 | 
			
		||||
            # TODO: test that shows stream raising an expected error!!!
 | 
			
		||||
            elif msg.get('error'):
 | 
			
		||||
| 
						 | 
				
			
			@ -162,10 +169,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		|||
            raise  # propagate
 | 
			
		||||
 | 
			
		||||
    async def aclose(self):
 | 
			
		||||
        """Cancel associated remote actor task and local memory channel
 | 
			
		||||
        on close.
 | 
			
		||||
        '''
 | 
			
		||||
        Cancel associated remote actor task and local memory channel on
 | 
			
		||||
        close.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        '''
 | 
			
		||||
        # XXX: keep proper adherance to trio's `.aclose()` semantics:
 | 
			
		||||
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
 | 
			
		||||
        rx_chan = self._rx_chan
 | 
			
		||||
| 
						 | 
				
			
			@ -178,7 +186,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		|||
            # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        self._eoc = True
 | 
			
		||||
        self._closed = True
 | 
			
		||||
 | 
			
		||||
        # NOTE: this is super subtle IPC messaging stuff:
 | 
			
		||||
        # Relay stop iteration to far end **iff** we're
 | 
			
		||||
| 
						 | 
				
			
			@ -310,15 +318,16 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
 | 
			
		|||
        self,
 | 
			
		||||
        data: Any
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        '''Send a message over this stream to the far end.
 | 
			
		||||
        '''
 | 
			
		||||
        Send a message over this stream to the far end.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        # if self._eoc:
 | 
			
		||||
        #     raise trio.ClosedResourceError('This stream is already ded')
 | 
			
		||||
 | 
			
		||||
        if self._ctx._error:
 | 
			
		||||
            raise self._ctx._error  # from None
 | 
			
		||||
 | 
			
		||||
        if self._closed:
 | 
			
		||||
            raise trio.ClosedResourceError('This stream was already closed')
 | 
			
		||||
 | 
			
		||||
        await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue