forked from goodboy/tractor
				
			Add a stream type factory
							parent
							
								
									6cf4a80fe4
								
							
						
					
					
						commit
						0d41f1410f
					
				|  | @ -5,7 +5,7 @@ Inter-process comms abstractions | ||||||
| import platform | import platform | ||||||
| import struct | import struct | ||||||
| import typing | import typing | ||||||
| from typing import Any, Tuple, Optional | from typing import Any, Tuple, Optional, Type | ||||||
| 
 | 
 | ||||||
| from tricycle import BufferedReceiveStream | from tricycle import BufferedReceiveStream | ||||||
| import msgpack | import msgpack | ||||||
|  | @ -55,6 +55,7 @@ class MsgpackTCPStream: | ||||||
|         unpacker = msgpack.Unpacker( |         unpacker = msgpack.Unpacker( | ||||||
|             raw=False, |             raw=False, | ||||||
|             use_list=False, |             use_list=False, | ||||||
|  |             strict_map_key=False | ||||||
|         ) |         ) | ||||||
|         while True: |         while True: | ||||||
|             try: |             try: | ||||||
|  | @ -130,12 +131,12 @@ class MsgspecTCPStream(MsgpackTCPStream): | ||||||
|         prefix_size: int = 4, |         prefix_size: int = 4, | ||||||
| 
 | 
 | ||||||
|     ) -> None: |     ) -> None: | ||||||
|  |         import msgspec | ||||||
|  | 
 | ||||||
|         super().__init__(stream) |         super().__init__(stream) | ||||||
|         self.recv_stream = BufferedReceiveStream(transport_stream=stream) |         self.recv_stream = BufferedReceiveStream(transport_stream=stream) | ||||||
|         self.prefix_size = prefix_size |         self.prefix_size = prefix_size | ||||||
| 
 | 
 | ||||||
|         import msgspec |  | ||||||
| 
 |  | ||||||
|         # TODO: struct aware messaging coders |         # TODO: struct aware messaging coders | ||||||
|         self.encode = msgspec.Encoder().encode |         self.encode = msgspec.Encoder().encode | ||||||
|         self.decode = msgspec.Decoder().decode  # dict[str, Any]) |         self.decode = msgspec.Decoder().decode  # dict[str, Any]) | ||||||
|  | @ -185,7 +186,7 @@ class MsgspecTCPStream(MsgpackTCPStream): | ||||||
|                 # ignore decoding errors for now and assume they have to |                 # ignore decoding errors for now and assume they have to | ||||||
|                 # do with a channel drop - hope that receiving from the |                 # do with a channel drop - hope that receiving from the | ||||||
|                 # channel will raise an expected error and bubble up. |                 # channel will raise an expected error and bubble up. | ||||||
|                 log.error(f'`msgspec` failed to decode!?') |                 log.error('`msgspec` failed to decode!?') | ||||||
|                 last_decode_failed = True |                 last_decode_failed = True | ||||||
| 
 | 
 | ||||||
|     async def send(self, data: Any) -> None: |     async def send(self, data: Any) -> None: | ||||||
|  | @ -200,11 +201,21 @@ class MsgspecTCPStream(MsgpackTCPStream): | ||||||
|             return await self.stream.send_all(size + bytes_data) |             return await self.stream.send_all(size + bytes_data) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def get_serializer_stream_type( | ||||||
|  |     name: str, | ||||||
|  | ) -> Type: | ||||||
|  |     return { | ||||||
|  |         'msgpack': MsgpackTCPStream, | ||||||
|  |         'msgspec': MsgspecTCPStream, | ||||||
|  |     }[name] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class Channel: | class Channel: | ||||||
|     """An inter-process channel for communication between (remote) actors. |     '''An inter-process channel for communication between (remote) actors. | ||||||
| 
 | 
 | ||||||
|     Currently the only supported transport is a ``trio.SocketStream``. |     Currently the only supported transport is a ``trio.SocketStream``. | ||||||
|     """ | 
 | ||||||
|  |     ''' | ||||||
|     def __init__( |     def __init__( | ||||||
| 
 | 
 | ||||||
|         self, |         self, | ||||||
|  | @ -218,17 +229,17 @@ class Channel: | ||||||
|         self._recon_seq = on_reconnect |         self._recon_seq = on_reconnect | ||||||
|         self._autorecon = auto_reconnect |         self._autorecon = auto_reconnect | ||||||
| 
 | 
 | ||||||
|         stream_serializer_type = MsgpackTCPStream |         # TODO: maybe expose this through the nursery api? | ||||||
| 
 |  | ||||||
|         try: |         try: | ||||||
|             # if installed load the msgspec transport since it's faster |             # if installed load the msgspec transport since it's faster | ||||||
|             import msgspec  # noqa |             import msgspec  # noqa | ||||||
|             stream_serializer_type = MsgspecTCPStream |             serializer = 'msgspec' | ||||||
|         except ImportError: |         except ImportError: | ||||||
|             pass |             serializer = 'msgpack' | ||||||
| 
 | 
 | ||||||
|         self.stream_serializer_type = stream_serializer_type |         self.stream_serializer_type = get_serializer_stream_type(serializer) | ||||||
|         self.msgstream = stream_serializer_type(stream) if stream else None |         self.msgstream = self.stream_serializer_type( | ||||||
|  |             stream) if stream else None | ||||||
| 
 | 
 | ||||||
|         if self.msgstream and destaddr: |         if self.msgstream and destaddr: | ||||||
|             raise ValueError( |             raise ValueError( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue