Drop `msgpack` lib and use `msgspec` for transport
							parent
							
								
									f6af5c7bf8
								
							
						
					
					
						commit
						f94b7cd991
					
				
							
								
								
									
										144
									
								
								tractor/_ipc.py
								
								
								
								
							
							
						
						
									
										144
									
								
								tractor/_ipc.py
								
								
								
								
							|  | @ -29,7 +29,7 @@ from typing import ( | |||
| ) | ||||
| 
 | ||||
| from tricycle import BufferedReceiveStream | ||||
| import msgpack | ||||
| import msgspec | ||||
| import trio | ||||
| from async_generator import asynccontextmanager | ||||
| 
 | ||||
|  | @ -98,12 +98,13 @@ class MsgTransport(Protocol[MsgType]): | |||
| class MsgpackTCPStream: | ||||
|     ''' | ||||
|     A ``trio.SocketStream`` delivering ``msgpack`` formatted data | ||||
|     using ``msgpack-python``. | ||||
|     using the ``msgspec`` codec lib. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         stream: trio.SocketStream, | ||||
|         prefix_size: int = 4, | ||||
| 
 | ||||
|     ) -> None: | ||||
| 
 | ||||
|  | @ -120,105 +121,6 @@ class MsgpackTCPStream: | |||
|         # public i guess? | ||||
|         self.drained: list[dict] = [] | ||||
| 
 | ||||
|     async def _iter_packets(self) -> AsyncGenerator[dict, None]: | ||||
|         ''' | ||||
|         Yield packets from the underlying stream. | ||||
| 
 | ||||
|         ''' | ||||
|         unpacker = msgpack.Unpacker( | ||||
|             raw=False, | ||||
|         ) | ||||
|         while True: | ||||
|             try: | ||||
|                 data = await self.stream.receive_some(2**10) | ||||
| 
 | ||||
|             except trio.BrokenResourceError as err: | ||||
|                 msg = err.args[0] | ||||
| 
 | ||||
|                 # XXX: handle connection-reset-by-peer the same as a EOF. | ||||
|                 # we're currently remapping this since we allow | ||||
|                 # a quick connect then drop for root actors when | ||||
|                 # checking to see if there exists an "arbiter" | ||||
|                 # on the chosen sockaddr (``_root.py:108`` or thereabouts) | ||||
|                 if ( | ||||
|                     # nix | ||||
|                     '[Errno 104]' in msg or | ||||
| 
 | ||||
|                     # on windows it seems there are a variety of errors | ||||
|                     # to handle.. | ||||
|                     _is_windows | ||||
|                 ): | ||||
|                     raise TransportClosed( | ||||
|                         f'{self} was broken with {msg}' | ||||
|                     ) | ||||
| 
 | ||||
|                 else: | ||||
|                     raise | ||||
| 
 | ||||
|             log.transport(f"received {data}")  # type: ignore | ||||
| 
 | ||||
|             if data == b'': | ||||
|                 raise TransportClosed( | ||||
|                     f'transport {self} was already closed prior to read' | ||||
|                 ) | ||||
| 
 | ||||
|             unpacker.feed(data) | ||||
|             for packet in unpacker: | ||||
|                 yield packet | ||||
| 
 | ||||
|     @property | ||||
|     def laddr(self) -> Tuple[Any, ...]: | ||||
|         return self._laddr | ||||
| 
 | ||||
|     @property | ||||
|     def raddr(self) -> Tuple[Any, ...]: | ||||
|         return self._raddr | ||||
| 
 | ||||
|     async def send(self, msg: Any) -> None: | ||||
|         async with self._send_lock: | ||||
|             return await self.stream.send_all( | ||||
|                 msgpack.dumps(msg, use_bin_type=True) | ||||
|             ) | ||||
| 
 | ||||
|     async def recv(self) -> Any: | ||||
|         return await self._agen.asend(None) | ||||
| 
 | ||||
|     async def drain(self) -> AsyncIterator[dict]: | ||||
|         ''' | ||||
|         Drain the stream's remaining messages sent from | ||||
|         the far end until the connection is closed by | ||||
|         the peer. | ||||
| 
 | ||||
|         ''' | ||||
|         try: | ||||
|             async for msg in self._iter_packets(): | ||||
|                 self.drained.append(msg) | ||||
|         except TransportClosed: | ||||
|             for msg in self.drained: | ||||
|                 yield msg | ||||
| 
 | ||||
|     def __aiter__(self): | ||||
|         return self._agen | ||||
| 
 | ||||
|     def connected(self) -> bool: | ||||
|         return self.stream.socket.fileno() != -1 | ||||
| 
 | ||||
| 
 | ||||
| class MsgspecTCPStream(MsgpackTCPStream): | ||||
|     ''' | ||||
|     A ``trio.SocketStream`` delivering ``msgpack`` formatted data | ||||
|     using ``msgspec``. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         stream: trio.SocketStream, | ||||
|         prefix_size: int = 4, | ||||
| 
 | ||||
|     ) -> None: | ||||
|         import msgspec | ||||
| 
 | ||||
|         super().__init__(stream) | ||||
|         self.recv_stream = BufferedReceiveStream(transport_stream=stream) | ||||
|         self.prefix_size = prefix_size | ||||
| 
 | ||||
|  | @ -287,6 +189,37 @@ class MsgspecTCPStream(MsgpackTCPStream): | |||
| 
 | ||||
|             return await self.stream.send_all(size + bytes_data) | ||||
| 
 | ||||
|     @property | ||||
|     def laddr(self) -> Tuple[Any, ...]: | ||||
|         return self._laddr | ||||
| 
 | ||||
|     @property | ||||
|     def raddr(self) -> Tuple[Any, ...]: | ||||
|         return self._raddr | ||||
| 
 | ||||
|     async def recv(self) -> Any: | ||||
|         return await self._agen.asend(None) | ||||
| 
 | ||||
|     async def drain(self) -> AsyncIterator[dict]: | ||||
|         ''' | ||||
|         Drain the stream's remaining messages sent from | ||||
|         the far end until the connection is closed by | ||||
|         the peer. | ||||
| 
 | ||||
|         ''' | ||||
|         try: | ||||
|             async for msg in self._iter_packets(): | ||||
|                 self.drained.append(msg) | ||||
|         except TransportClosed: | ||||
|             for msg in self.drained: | ||||
|                 yield msg | ||||
| 
 | ||||
|     def __aiter__(self): | ||||
|         return self._agen | ||||
| 
 | ||||
|     def connected(self) -> bool: | ||||
|         return self.stream.socket.fileno() != -1 | ||||
| 
 | ||||
| 
 | ||||
| def get_msg_transport( | ||||
| 
 | ||||
|  | @ -296,7 +229,6 @@ def get_msg_transport( | |||
| 
 | ||||
|     return { | ||||
|         ('msgpack', 'tcp'): MsgpackTCPStream, | ||||
|         ('msgspec', 'tcp'): MsgspecTCPStream, | ||||
|     }[key] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -325,14 +257,6 @@ class Channel: | |||
|         # self._recon_seq = on_reconnect | ||||
|         # self._autorecon = auto_reconnect | ||||
| 
 | ||||
|         # TODO: maybe expose this through the nursery api? | ||||
|         try: | ||||
|             # if installed load the msgspec transport since it's faster | ||||
|             import msgspec  # noqa | ||||
|             msg_transport_type_key = ('msgspec', 'tcp') | ||||
|         except ImportError: | ||||
|             pass | ||||
| 
 | ||||
|         self._destaddr = destaddr | ||||
|         self._transport_key = msg_transport_type_key | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue