forked from goodboy/tractor
1
0
Fork 0

Add our own "transport closed" signal

This change some super old (and bad) code from the project's very early
days. For some redic reason i must have thought masking `trio`'s
internal stream / transport errors and a TCP EOF as `StopAsyncIteration`
somehow a good idea. The reality is you probably
want to know the difference between an unexpected transport error
and a simple EOF lol. This begins to resolve that by adding our own
special `TransportClosed` error to signal the "graceful" termination of
a channel's underlying transport. Oh, and this builds on the `msgspec`
integration which helped shed light on the core issues here B)
prehardkill
Tyler Goodlet 2021-06-24 18:49:51 -04:00
parent 62ece4327d
commit e6c9232b45
2 changed files with 54 additions and 27 deletions

View File

@ -41,6 +41,10 @@ class ContextCancelled(RemoteActorError):
"Inter-actor task context cancelled itself on the callee side." "Inter-actor task context cancelled itself on the callee side."
class TransportClosed(trio.ClosedResourceError):
"Underlying channel transport was closed prior to use"
class NoResult(RuntimeError): class NoResult(RuntimeError):
"No final result is expected for this actor" "No final result is expected for this actor"
@ -66,12 +70,15 @@ def pack_error(exc: BaseException) -> Dict[str, Any]:
def unpack_error( def unpack_error(
msg: Dict[str, Any], msg: Dict[str, Any],
chan=None, chan=None,
err_type=RemoteActorError err_type=RemoteActorError
) -> Exception: ) -> Exception:
"""Unpack an 'error' message from the wire """Unpack an 'error' message from the wire
into a local ``RemoteActorError``. into a local ``RemoteActorError``.
""" """
error = msg['error'] error = msg['error']

View File

@ -2,6 +2,7 @@
Inter-process comms abstractions Inter-process comms abstractions
""" """
from functools import partial from functools import partial
import math
import struct import struct
import typing import typing
from typing import Any, Tuple, Optional from typing import Any, Tuple, Optional
@ -13,6 +14,7 @@ import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
from .log import get_logger from .log import get_logger
from ._exceptions import TransportClosed
log = get_logger(__name__) log = get_logger(__name__)
# :eyeroll: # :eyeroll:
@ -24,7 +26,7 @@ except ImportError:
Unpacker = partial(msgpack.Unpacker, strict_map_key=False) Unpacker = partial(msgpack.Unpacker, strict_map_key=False)
class MsgpackStream: class MsgpackTCPStream:
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``. using ``msgpack-python``.
@ -47,7 +49,10 @@ class MsgpackStream:
assert isinstance(rsockname, tuple) assert isinstance(rsockname, tuple)
self._raddr = rsockname[:2] self._raddr = rsockname[:2]
# start and seed first entry to read loop
self._agen = self._iter_packets() self._agen = self._iter_packets()
# self._agen.asend(None) is None
self._send_lock = trio.StrictFIFOLock() self._send_lock = trio.StrictFIFOLock()
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
@ -58,16 +63,13 @@ class MsgpackStream:
use_list=False, use_list=False,
) )
while True: while True:
try:
data = await self.stream.receive_some(2**10) data = await self.stream.receive_some(2**10)
log.trace(f"received {data}") # type: ignore log.trace(f"received {data}") # type: ignore
except trio.BrokenResourceError:
log.warning(f"Stream connection {self.raddr} broke")
return
if data == b'': if data == b'':
log.debug(f"Stream connection {self.raddr} was closed") raise TransportClosed(
return f'transport {self} was already closed prior ro read'
)
unpacker.feed(data) unpacker.feed(data)
for packet in unpacker: for packet in unpacker:
@ -98,7 +100,7 @@ class MsgpackStream:
return self.stream.socket.fileno() != -1 return self.stream.socket.fileno() != -1
class MsgspecStream(MsgpackStream): class MsgspecTCPStream(MsgpackTCPStream):
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgspec``. using ``msgspec``.
@ -123,12 +125,16 @@ class MsgspecStream(MsgpackStream):
while True: while True:
try: try:
header = await self.recv_stream.receive_exactly(4) header = await self.recv_stream.receive_exactly(4)
if header is None:
continue except (ValueError):
raise TransportClosed(
f'transport {self} was already closed prior ro read'
)
if header == b'': if header == b'':
log.debug(f"Stream connection {self.raddr} was closed") raise TransportClosed(
return f'transport {self} was already closed prior ro read'
)
size, = struct.unpack("<I", header) size, = struct.unpack("<I", header)
@ -136,12 +142,6 @@ class MsgspecStream(MsgpackStream):
msg_bytes = await self.recv_stream.receive_exactly(size) msg_bytes = await self.recv_stream.receive_exactly(size)
# the value error here is to catch a connect with immediate
# disconnect that will cause an EOF error inside `tricycle`.
except (ValueError, trio.BrokenResourceError):
log.warning(f"Stream connection {self.raddr} broke")
return
log.trace(f"received {msg_bytes}") # type: ignore log.trace(f"received {msg_bytes}") # type: ignore
yield decoder.decode(msg_bytes) yield decoder.decode(msg_bytes)
@ -169,8 +169,9 @@ class Channel:
on_reconnect: typing.Callable[..., typing.Awaitable] = None, on_reconnect: typing.Callable[..., typing.Awaitable] = None,
auto_reconnect: bool = False, auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active stream: trio.SocketStream = None, # expected to be active
# stream_serializer: type = MsgpackStream,
stream_serializer_type: type = MsgspecStream, # stream_serializer_type: type = MsgspecTCPStream,
stream_serializer_type: type = MsgpackTCPStream,
) -> None: ) -> None:
@ -192,6 +193,8 @@ class Channel:
self._exc: Optional[Exception] = None self._exc: Optional[Exception] = None
self._agen = self._aiter_recv() self._agen = self._aiter_recv()
self._closed: bool = False
def __repr__(self) -> str: def __repr__(self) -> str:
if self.msgstream: if self.msgstream:
return repr( return repr(
@ -208,35 +211,52 @@ class Channel:
return self.msgstream.raddr if self.msgstream else None return self.msgstream.raddr if self.msgstream else None
async def connect( async def connect(
self, destaddr: Tuple[Any, ...] = None, self,
destaddr: Tuple[Any, ...] = None,
**kwargs **kwargs
) -> trio.SocketStream: ) -> trio.SocketStream:
if self.connected(): if self.connected():
raise RuntimeError("channel is already connected?") raise RuntimeError("channel is already connected?")
destaddr = destaddr or self._destaddr destaddr = destaddr or self._destaddr
assert isinstance(destaddr, tuple) assert isinstance(destaddr, tuple)
stream = await trio.open_tcp_stream(*destaddr, **kwargs)
stream = await trio.open_tcp_stream(
*destaddr,
happy_eyeballs_delay=math.inf,
**kwargs
)
self.msgstream = self.stream_serializer_type(stream) self.msgstream = self.stream_serializer_type(stream)
return stream return stream
async def send(self, item: Any) -> None: async def send(self, item: Any) -> None:
log.trace(f"send `{item}`") # type: ignore log.trace(f"send `{item}`") # type: ignore
assert self.msgstream assert self.msgstream
await self.msgstream.send(item) await self.msgstream.send(item)
async def recv(self) -> Any: async def recv(self) -> Any:
assert self.msgstream assert self.msgstream
try: try:
return await self.msgstream.recv() return await self.msgstream.recv()
except trio.BrokenResourceError: except trio.BrokenResourceError:
if self._autorecon: if self._autorecon:
await self._reconnect() await self._reconnect()
return await self.recv() return await self.recv()
raise
async def aclose(self) -> None: async def aclose(self) -> None:
log.debug(f"Closing {self}") log.debug(f"Closing {self}")
assert self.msgstream assert self.msgstream
await self.msgstream.stream.aclose() await self.msgstream.stream.aclose()
self._closed = True
log.error(f'CLOSING CHAN {self}')
async def __aenter__(self): async def __aenter__(self):
await self.connect() await self.connect()