forked from goodboy/tractor
				
			Drop `msgpack` lib and use `msgspec` for transport
							parent
							
								
									f6af5c7bf8
								
							
						
					
					
						commit
						f94b7cd991
					
				
							
								
								
									
										144
									
								
								tractor/_ipc.py
								
								
								
								
							
							
						
						
									
										144
									
								
								tractor/_ipc.py
								
								
								
								
							|  | @ -29,7 +29,7 @@ from typing import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from tricycle import BufferedReceiveStream | from tricycle import BufferedReceiveStream | ||||||
| import msgpack | import msgspec | ||||||
| import trio | import trio | ||||||
| from async_generator import asynccontextmanager | from async_generator import asynccontextmanager | ||||||
| 
 | 
 | ||||||
|  | @ -98,12 +98,13 @@ class MsgTransport(Protocol[MsgType]): | ||||||
| class MsgpackTCPStream: | class MsgpackTCPStream: | ||||||
|     ''' |     ''' | ||||||
|     A ``trio.SocketStream`` delivering ``msgpack`` formatted data |     A ``trio.SocketStream`` delivering ``msgpack`` formatted data | ||||||
|     using ``msgpack-python``. |     using the ``msgspec`` codec lib. | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         stream: trio.SocketStream, |         stream: trio.SocketStream, | ||||||
|  |         prefix_size: int = 4, | ||||||
| 
 | 
 | ||||||
|     ) -> None: |     ) -> None: | ||||||
| 
 | 
 | ||||||
|  | @ -120,105 +121,6 @@ class MsgpackTCPStream: | ||||||
|         # public i guess? |         # public i guess? | ||||||
|         self.drained: list[dict] = [] |         self.drained: list[dict] = [] | ||||||
| 
 | 
 | ||||||
|     async def _iter_packets(self) -> AsyncGenerator[dict, None]: |  | ||||||
|         ''' |  | ||||||
|         Yield packets from the underlying stream. |  | ||||||
| 
 |  | ||||||
|         ''' |  | ||||||
|         unpacker = msgpack.Unpacker( |  | ||||||
|             raw=False, |  | ||||||
|         ) |  | ||||||
|         while True: |  | ||||||
|             try: |  | ||||||
|                 data = await self.stream.receive_some(2**10) |  | ||||||
| 
 |  | ||||||
|             except trio.BrokenResourceError as err: |  | ||||||
|                 msg = err.args[0] |  | ||||||
| 
 |  | ||||||
|                 # XXX: handle connection-reset-by-peer the same as a EOF. |  | ||||||
|                 # we're currently remapping this since we allow |  | ||||||
|                 # a quick connect then drop for root actors when |  | ||||||
|                 # checking to see if there exists an "arbiter" |  | ||||||
|                 # on the chosen sockaddr (``_root.py:108`` or thereabouts) |  | ||||||
|                 if ( |  | ||||||
|                     # nix |  | ||||||
|                     '[Errno 104]' in msg or |  | ||||||
| 
 |  | ||||||
|                     # on windows it seems there are a variety of errors |  | ||||||
|                     # to handle.. |  | ||||||
|                     _is_windows |  | ||||||
|                 ): |  | ||||||
|                     raise TransportClosed( |  | ||||||
|                         f'{self} was broken with {msg}' |  | ||||||
|                     ) |  | ||||||
| 
 |  | ||||||
|                 else: |  | ||||||
|                     raise |  | ||||||
| 
 |  | ||||||
|             log.transport(f"received {data}")  # type: ignore |  | ||||||
| 
 |  | ||||||
|             if data == b'': |  | ||||||
|                 raise TransportClosed( |  | ||||||
|                     f'transport {self} was already closed prior to read' |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|             unpacker.feed(data) |  | ||||||
|             for packet in unpacker: |  | ||||||
|                 yield packet |  | ||||||
| 
 |  | ||||||
|     @property |  | ||||||
|     def laddr(self) -> Tuple[Any, ...]: |  | ||||||
|         return self._laddr |  | ||||||
| 
 |  | ||||||
|     @property |  | ||||||
|     def raddr(self) -> Tuple[Any, ...]: |  | ||||||
|         return self._raddr |  | ||||||
| 
 |  | ||||||
|     async def send(self, msg: Any) -> None: |  | ||||||
|         async with self._send_lock: |  | ||||||
|             return await self.stream.send_all( |  | ||||||
|                 msgpack.dumps(msg, use_bin_type=True) |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|     async def recv(self) -> Any: |  | ||||||
|         return await self._agen.asend(None) |  | ||||||
| 
 |  | ||||||
|     async def drain(self) -> AsyncIterator[dict]: |  | ||||||
|         ''' |  | ||||||
|         Drain the stream's remaining messages sent from |  | ||||||
|         the far end until the connection is closed by |  | ||||||
|         the peer. |  | ||||||
| 
 |  | ||||||
|         ''' |  | ||||||
|         try: |  | ||||||
|             async for msg in self._iter_packets(): |  | ||||||
|                 self.drained.append(msg) |  | ||||||
|         except TransportClosed: |  | ||||||
|             for msg in self.drained: |  | ||||||
|                 yield msg |  | ||||||
| 
 |  | ||||||
|     def __aiter__(self): |  | ||||||
|         return self._agen |  | ||||||
| 
 |  | ||||||
|     def connected(self) -> bool: |  | ||||||
|         return self.stream.socket.fileno() != -1 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class MsgspecTCPStream(MsgpackTCPStream): |  | ||||||
|     ''' |  | ||||||
|     A ``trio.SocketStream`` delivering ``msgpack`` formatted data |  | ||||||
|     using ``msgspec``. |  | ||||||
| 
 |  | ||||||
|     ''' |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         stream: trio.SocketStream, |  | ||||||
|         prefix_size: int = 4, |  | ||||||
| 
 |  | ||||||
|     ) -> None: |  | ||||||
|         import msgspec |  | ||||||
| 
 |  | ||||||
|         super().__init__(stream) |  | ||||||
|         self.recv_stream = BufferedReceiveStream(transport_stream=stream) |         self.recv_stream = BufferedReceiveStream(transport_stream=stream) | ||||||
|         self.prefix_size = prefix_size |         self.prefix_size = prefix_size | ||||||
| 
 | 
 | ||||||
|  | @ -287,6 +189,37 @@ class MsgspecTCPStream(MsgpackTCPStream): | ||||||
| 
 | 
 | ||||||
|             return await self.stream.send_all(size + bytes_data) |             return await self.stream.send_all(size + bytes_data) | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def laddr(self) -> Tuple[Any, ...]: | ||||||
|  |         return self._laddr | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def raddr(self) -> Tuple[Any, ...]: | ||||||
|  |         return self._raddr | ||||||
|  | 
 | ||||||
|  |     async def recv(self) -> Any: | ||||||
|  |         return await self._agen.asend(None) | ||||||
|  | 
 | ||||||
|  |     async def drain(self) -> AsyncIterator[dict]: | ||||||
|  |         ''' | ||||||
|  |         Drain the stream's remaining messages sent from | ||||||
|  |         the far end until the connection is closed by | ||||||
|  |         the peer. | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         try: | ||||||
|  |             async for msg in self._iter_packets(): | ||||||
|  |                 self.drained.append(msg) | ||||||
|  |         except TransportClosed: | ||||||
|  |             for msg in self.drained: | ||||||
|  |                 yield msg | ||||||
|  | 
 | ||||||
|  |     def __aiter__(self): | ||||||
|  |         return self._agen | ||||||
|  | 
 | ||||||
|  |     def connected(self) -> bool: | ||||||
|  |         return self.stream.socket.fileno() != -1 | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def get_msg_transport( | def get_msg_transport( | ||||||
| 
 | 
 | ||||||
|  | @ -296,7 +229,6 @@ def get_msg_transport( | ||||||
| 
 | 
 | ||||||
|     return { |     return { | ||||||
|         ('msgpack', 'tcp'): MsgpackTCPStream, |         ('msgpack', 'tcp'): MsgpackTCPStream, | ||||||
|         ('msgspec', 'tcp'): MsgspecTCPStream, |  | ||||||
|     }[key] |     }[key] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -325,14 +257,6 @@ class Channel: | ||||||
|         # self._recon_seq = on_reconnect |         # self._recon_seq = on_reconnect | ||||||
|         # self._autorecon = auto_reconnect |         # self._autorecon = auto_reconnect | ||||||
| 
 | 
 | ||||||
|         # TODO: maybe expose this through the nursery api? |  | ||||||
|         try: |  | ||||||
|             # if installed load the msgspec transport since it's faster |  | ||||||
|             import msgspec  # noqa |  | ||||||
|             msg_transport_type_key = ('msgspec', 'tcp') |  | ||||||
|         except ImportError: |  | ||||||
|             pass |  | ||||||
| 
 |  | ||||||
|         self._destaddr = destaddr |         self._destaddr = destaddr | ||||||
|         self._transport_key = msg_transport_type_key |         self._transport_key = msg_transport_type_key | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue