forked from goodboy/tractor
				
			Add our own "transport closed" signal
This change some super old (and bad) code from the project's very early days. For some redic reason i must have thought masking `trio`'s internal stream / transport errors and a TCP EOF as `StopAsyncIteration` somehow a good idea. The reality is you probably want to know the difference between an unexpected transport error and a simple EOF lol. This begins to resolve that by adding our own special `TransportClosed` error to signal the "graceful" termination of a channel's underlying transport. Oh, and this builds on the `msgspec` integration which helped shed light on the core issues here B)try_msgspec
							parent
							
								
									97f44e2e27
								
							
						
					
					
						commit
						1a068add5d
					
				|  | @ -41,6 +41,10 @@ class ContextCancelled(RemoteActorError): | ||||||
|     "Inter-actor task context cancelled itself on the callee side." |     "Inter-actor task context cancelled itself on the callee side." | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class TransportClosed(trio.ClosedResourceError): | ||||||
|  |     "Underlying channel transport was closed prior to use" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class NoResult(RuntimeError): | class NoResult(RuntimeError): | ||||||
|     "No final result is expected for this actor" |     "No final result is expected for this actor" | ||||||
| 
 | 
 | ||||||
|  | @ -66,12 +70,15 @@ def pack_error(exc: BaseException) -> Dict[str, Any]: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def unpack_error( | def unpack_error( | ||||||
|  | 
 | ||||||
|     msg: Dict[str, Any], |     msg: Dict[str, Any], | ||||||
|     chan=None, |     chan=None, | ||||||
|     err_type=RemoteActorError |     err_type=RemoteActorError | ||||||
|  | 
 | ||||||
| ) -> Exception: | ) -> Exception: | ||||||
|     """Unpack an 'error' message from the wire |     """Unpack an 'error' message from the wire | ||||||
|     into a local ``RemoteActorError``. |     into a local ``RemoteActorError``. | ||||||
|  | 
 | ||||||
|     """ |     """ | ||||||
|     error = msg['error'] |     error = msg['error'] | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -2,6 +2,7 @@ | ||||||
| Inter-process comms abstractions | Inter-process comms abstractions | ||||||
| """ | """ | ||||||
| from functools import partial | from functools import partial | ||||||
|  | import math | ||||||
| import struct | import struct | ||||||
| import typing | import typing | ||||||
| from typing import Any, Tuple, Optional | from typing import Any, Tuple, Optional | ||||||
|  | @ -13,6 +14,7 @@ import trio | ||||||
| from async_generator import asynccontextmanager | from async_generator import asynccontextmanager | ||||||
| 
 | 
 | ||||||
| from .log import get_logger | from .log import get_logger | ||||||
|  | from ._exceptions import TransportClosed | ||||||
| log = get_logger(__name__) | log = get_logger(__name__) | ||||||
| 
 | 
 | ||||||
| # :eyeroll: | # :eyeroll: | ||||||
|  | @ -24,7 +26,7 @@ except ImportError: | ||||||
|     Unpacker = partial(msgpack.Unpacker, strict_map_key=False) |     Unpacker = partial(msgpack.Unpacker, strict_map_key=False) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class MsgpackStream: | class MsgpackTCPStream: | ||||||
|     '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data |     '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data | ||||||
|     using ``msgpack-python``. |     using ``msgpack-python``. | ||||||
| 
 | 
 | ||||||
|  | @ -47,7 +49,10 @@ class MsgpackStream: | ||||||
|         assert isinstance(rsockname, tuple) |         assert isinstance(rsockname, tuple) | ||||||
|         self._raddr = rsockname[:2] |         self._raddr = rsockname[:2] | ||||||
| 
 | 
 | ||||||
|  |         # start and seed first entry to read loop | ||||||
|         self._agen = self._iter_packets() |         self._agen = self._iter_packets() | ||||||
|  |         # self._agen.asend(None) is None | ||||||
|  | 
 | ||||||
|         self._send_lock = trio.StrictFIFOLock() |         self._send_lock = trio.StrictFIFOLock() | ||||||
| 
 | 
 | ||||||
|     async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: |     async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: | ||||||
|  | @ -58,16 +63,13 @@ class MsgpackStream: | ||||||
|             use_list=False, |             use_list=False, | ||||||
|         ) |         ) | ||||||
|         while True: |         while True: | ||||||
|             try: |             data = await self.stream.receive_some(2**10) | ||||||
|                 data = await self.stream.receive_some(2**10) |             log.trace(f"received {data}")  # type: ignore | ||||||
|                 log.trace(f"received {data}")  # type: ignore |  | ||||||
|             except trio.BrokenResourceError: |  | ||||||
|                 log.warning(f"Stream connection {self.raddr} broke") |  | ||||||
|                 return |  | ||||||
| 
 | 
 | ||||||
|             if data == b'': |             if data == b'': | ||||||
|                 log.debug(f"Stream connection {self.raddr} was closed") |                 raise TransportClosed( | ||||||
|                 return |                     f'transport {self} was already closed prior ro read' | ||||||
|  |                 ) | ||||||
| 
 | 
 | ||||||
|             unpacker.feed(data) |             unpacker.feed(data) | ||||||
|             for packet in unpacker: |             for packet in unpacker: | ||||||
|  | @ -98,7 +100,7 @@ class MsgpackStream: | ||||||
|         return self.stream.socket.fileno() != -1 |         return self.stream.socket.fileno() != -1 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class MsgspecStream(MsgpackStream): | class MsgspecTCPStream(MsgpackTCPStream): | ||||||
|     '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data |     '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data | ||||||
|     using ``msgspec``. |     using ``msgspec``. | ||||||
| 
 | 
 | ||||||
|  | @ -123,24 +125,22 @@ class MsgspecStream(MsgpackStream): | ||||||
|         while True: |         while True: | ||||||
|             try: |             try: | ||||||
|                 header = await self.recv_stream.receive_exactly(4) |                 header = await self.recv_stream.receive_exactly(4) | ||||||
|                 if header is None: |  | ||||||
|                     continue |  | ||||||
| 
 | 
 | ||||||
|                 if header == b'': |             except (ValueError): | ||||||
|                     log.debug(f"Stream connection {self.raddr} was closed") |                 raise TransportClosed( | ||||||
|                     return |                     f'transport {self} was already closed prior ro read' | ||||||
|  |                 ) | ||||||
| 
 | 
 | ||||||
|                 size, = struct.unpack("<I", header) |             if header == b'': | ||||||
|  |                 raise TransportClosed( | ||||||
|  |                     f'transport {self} was already closed prior ro read' | ||||||
|  |                 ) | ||||||
| 
 | 
 | ||||||
|                 log.trace(f'received header {size}') |             size, = struct.unpack("<I", header) | ||||||
| 
 | 
 | ||||||
|                 msg_bytes = await self.recv_stream.receive_exactly(size) |             log.trace(f'received header {size}') | ||||||
| 
 | 
 | ||||||
|             # the value error here is to catch a connect with immediate |             msg_bytes = await self.recv_stream.receive_exactly(size) | ||||||
|             # disconnect that will cause an EOF error inside `tricycle`. |  | ||||||
|             except (ValueError, trio.BrokenResourceError): |  | ||||||
|                 log.warning(f"Stream connection {self.raddr} broke") |  | ||||||
|                 return |  | ||||||
| 
 | 
 | ||||||
|             log.trace(f"received {msg_bytes}")  # type: ignore |             log.trace(f"received {msg_bytes}")  # type: ignore | ||||||
|             yield decoder.decode(msg_bytes) |             yield decoder.decode(msg_bytes) | ||||||
|  | @ -169,8 +169,9 @@ class Channel: | ||||||
|         on_reconnect: typing.Callable[..., typing.Awaitable] = None, |         on_reconnect: typing.Callable[..., typing.Awaitable] = None, | ||||||
|         auto_reconnect: bool = False, |         auto_reconnect: bool = False, | ||||||
|         stream: trio.SocketStream = None,  # expected to be active |         stream: trio.SocketStream = None,  # expected to be active | ||||||
|         # stream_serializer: type = MsgpackStream, | 
 | ||||||
|         stream_serializer_type: type = MsgspecStream, |         # stream_serializer_type: type = MsgspecTCPStream, | ||||||
|  |         stream_serializer_type: type = MsgpackTCPStream, | ||||||
| 
 | 
 | ||||||
|     ) -> None: |     ) -> None: | ||||||
| 
 | 
 | ||||||
|  | @ -192,6 +193,8 @@ class Channel: | ||||||
|         self._exc: Optional[Exception] = None |         self._exc: Optional[Exception] = None | ||||||
|         self._agen = self._aiter_recv() |         self._agen = self._aiter_recv() | ||||||
| 
 | 
 | ||||||
|  |         self._closed: bool = False | ||||||
|  | 
 | ||||||
|     def __repr__(self) -> str: |     def __repr__(self) -> str: | ||||||
|         if self.msgstream: |         if self.msgstream: | ||||||
|             return repr( |             return repr( | ||||||
|  | @ -208,35 +211,52 @@ class Channel: | ||||||
|         return self.msgstream.raddr if self.msgstream else None |         return self.msgstream.raddr if self.msgstream else None | ||||||
| 
 | 
 | ||||||
|     async def connect( |     async def connect( | ||||||
|         self, destaddr: Tuple[Any, ...] = None, |         self, | ||||||
|  |         destaddr: Tuple[Any, ...] = None, | ||||||
|         **kwargs |         **kwargs | ||||||
|  | 
 | ||||||
|     ) -> trio.SocketStream: |     ) -> trio.SocketStream: | ||||||
|  | 
 | ||||||
|         if self.connected(): |         if self.connected(): | ||||||
|             raise RuntimeError("channel is already connected?") |             raise RuntimeError("channel is already connected?") | ||||||
|  | 
 | ||||||
|         destaddr = destaddr or self._destaddr |         destaddr = destaddr or self._destaddr | ||||||
|         assert isinstance(destaddr, tuple) |         assert isinstance(destaddr, tuple) | ||||||
|         stream = await trio.open_tcp_stream(*destaddr, **kwargs) | 
 | ||||||
|  |         stream = await trio.open_tcp_stream( | ||||||
|  |             *destaddr, | ||||||
|  |             happy_eyeballs_delay=math.inf, | ||||||
|  |             **kwargs | ||||||
|  |         ) | ||||||
|         self.msgstream = self.stream_serializer_type(stream) |         self.msgstream = self.stream_serializer_type(stream) | ||||||
|         return stream |         return stream | ||||||
| 
 | 
 | ||||||
|     async def send(self, item: Any) -> None: |     async def send(self, item: Any) -> None: | ||||||
|  | 
 | ||||||
|         log.trace(f"send `{item}`")  # type: ignore |         log.trace(f"send `{item}`")  # type: ignore | ||||||
|         assert self.msgstream |         assert self.msgstream | ||||||
|  | 
 | ||||||
|         await self.msgstream.send(item) |         await self.msgstream.send(item) | ||||||
| 
 | 
 | ||||||
|     async def recv(self) -> Any: |     async def recv(self) -> Any: | ||||||
|         assert self.msgstream |         assert self.msgstream | ||||||
|  | 
 | ||||||
|         try: |         try: | ||||||
|             return await self.msgstream.recv() |             return await self.msgstream.recv() | ||||||
|  | 
 | ||||||
|         except trio.BrokenResourceError: |         except trio.BrokenResourceError: | ||||||
|             if self._autorecon: |             if self._autorecon: | ||||||
|                 await self._reconnect() |                 await self._reconnect() | ||||||
|                 return await self.recv() |                 return await self.recv() | ||||||
| 
 | 
 | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|     async def aclose(self) -> None: |     async def aclose(self) -> None: | ||||||
|         log.debug(f"Closing {self}") |         log.debug(f"Closing {self}") | ||||||
|         assert self.msgstream |         assert self.msgstream | ||||||
|         await self.msgstream.stream.aclose() |         await self.msgstream.stream.aclose() | ||||||
|  |         self._closed = True | ||||||
|  |         log.error(f'CLOSING CHAN {self}') | ||||||
| 
 | 
 | ||||||
|     async def __aenter__(self): |     async def __aenter__(self): | ||||||
|         await self.connect() |         await self.connect() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue