Switch `tractor.ipc.MsgTransport.stream` type to `trio.abc.Stream`

Add EOF signaling mechanism
Support proper `receive_some` end of stream semantics
Add StapledStream non-ipc test
Create MsgpackRBStream similar to MsgpackTCPStream for buffered whole-msg reads
Add EventFD.read cancellation on EventFD.close mechanism using cancel scope
Add test for eventfd cancellation
Improve and add docstrings
Guillermo Rodriguez 2025-03-16 23:57:26 -03:00
parent d6721f06df
commit 5cec4ee943
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
8 changed files with 671 additions and 75 deletions

View File

@ -0,0 +1,32 @@
import trio
import pytest
from tractor.ipc import (
open_eventfd,
EFDReadCancelled,
EventFD
)
def test_eventfd_read_cancellation():
'''
Ensure EventFD.read raises EFDReadCancelled if EventFD.close()
is called.
'''
fd = open_eventfd()
async def _read(event: EventFD):
with pytest.raises(EFDReadCancelled):
await event.read()
async def main():
async with trio.open_nursery() as n:
with (
EventFD(fd, 'w') as event,
trio.fail_after(3)
):
n.start_soon(_read, event)
await trio.sleep(0.2)
event.close()
trio.run(main)

View File

@ -5,11 +5,16 @@ import pytest
import tractor import tractor
from tractor.ipc import ( from tractor.ipc import (
open_ringbuf, open_ringbuf,
attach_to_ringbuf_receiver,
attach_to_ringbuf_sender,
attach_to_ringbuf_pair,
attach_to_ringbuf_stream,
RBToken, RBToken,
RingBuffSender,
RingBuffReceiver
) )
from tractor._testing.samples import generate_sample_messages from tractor._testing.samples import (
generate_single_byte_msgs,
generate_sample_messages
)
@tractor.context @tractor.context
@ -17,20 +22,14 @@ async def child_read_shm(
ctx: tractor.Context, ctx: tractor.Context,
msg_amount: int, msg_amount: int,
token: RBToken, token: RBToken,
total_bytes: int,
) -> None: ) -> None:
recvd_bytes = 0 recvd_bytes = 0
await ctx.started() await ctx.started()
start_ts = time.time() start_ts = time.time()
async with RingBuffReceiver(token) as receiver: async with attach_to_ringbuf_receiver(token) as receiver:
while recvd_bytes < total_bytes: async for msg in receiver:
msg = await receiver.receive_some()
recvd_bytes += len(msg) recvd_bytes += len(msg)
# make sure we dont hold any memoryviews
# before the ctx manager aclose()
msg = None
end_ts = time.time() end_ts = time.time()
elapsed = end_ts - start_ts elapsed = end_ts - start_ts
elapsed_ms = int(elapsed * 1000) elapsed_ms = int(elapsed * 1000)
@ -38,6 +37,7 @@ async def child_read_shm(
print(f'\n\telapsed ms: {elapsed_ms}') print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}') print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}') print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
print(f'\treceived bytes: {recvd_bytes}')
@tractor.context @tractor.context
@ -54,7 +54,7 @@ async def child_write_shm(
rand_max=rand_max, rand_max=rand_max,
) )
await ctx.started(total_bytes) await ctx.started(total_bytes)
async with RingBuffSender(token) as sender: async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
for msg in msgs: for msg in msgs:
await sender.send_all(msg) await sender.send_all(msg)
@ -99,14 +99,8 @@ def test_ringbuf(
'test_ringbuf', 'test_ringbuf',
buf_size=buf_size buf_size=buf_size
) as token: ) as token:
proc_kwargs = { proc_kwargs = {'pass_fds': token.fds}
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
}
common_kwargs = {
'msg_amount': msg_amount,
'token': token,
}
async with tractor.open_nursery() as an: async with tractor.open_nursery() as an:
send_p = await an.start_actor( send_p = await an.start_actor(
'ring_sender', 'ring_sender',
@ -121,14 +115,15 @@ def test_ringbuf(
async with ( async with (
send_p.open_context( send_p.open_context(
child_write_shm, child_write_shm,
token=token,
msg_amount=msg_amount,
rand_min=rand_min, rand_min=rand_min,
rand_max=rand_max, rand_max=rand_max,
**common_kwargs
) as (sctx, total_bytes), ) as (sctx, total_bytes),
recv_p.open_context( recv_p.open_context(
child_read_shm, child_read_shm,
**common_kwargs, token=token,
total_bytes=total_bytes, msg_amount=msg_amount
) as (sctx, _sent), ) as (sctx, _sent),
): ):
await recv_p.result() await recv_p.result()
@ -145,7 +140,7 @@ async def child_blocked_receiver(
ctx: tractor.Context, ctx: tractor.Context,
token: RBToken token: RBToken
): ):
async with RingBuffReceiver(token) as receiver: async with attach_to_ringbuf_receiver(token) as receiver:
await ctx.started() await ctx.started()
await receiver.receive_some() await receiver.receive_some()
@ -160,13 +155,13 @@ def test_ring_reader_cancel():
with open_ringbuf('test_ring_cancel_reader') as token: with open_ringbuf('test_ring_cancel_reader') as token:
async with ( async with (
tractor.open_nursery() as an, tractor.open_nursery() as an,
RingBuffSender(token) as _sender, attach_to_ringbuf_sender(token) as _sender,
): ):
recv_p = await an.start_actor( recv_p = await an.start_actor(
'ring_blocked_receiver', 'ring_blocked_receiver',
enable_modules=[__name__], enable_modules=[__name__],
proc_kwargs={ proc_kwargs={
'pass_fds': (token.write_eventfd, token.wrap_eventfd) 'pass_fds': token.fds
} }
) )
async with ( async with (
@ -188,7 +183,7 @@ async def child_blocked_sender(
ctx: tractor.Context, ctx: tractor.Context,
token: RBToken token: RBToken
): ):
async with RingBuffSender(token) as sender: async with attach_to_ringbuf_sender(token) as sender:
await ctx.started() await ctx.started()
await sender.send_all(b'this will wrap') await sender.send_all(b'this will wrap')
@ -209,7 +204,7 @@ def test_ring_sender_cancel():
'ring_blocked_sender', 'ring_blocked_sender',
enable_modules=[__name__], enable_modules=[__name__],
proc_kwargs={ proc_kwargs={
'pass_fds': (token.write_eventfd, token.wrap_eventfd) 'pass_fds': token.fds
} }
) )
async with ( async with (
@ -235,7 +230,7 @@ def test_ringbuf_max_bytes():
msgs with original message msgs with original message
''' '''
msg = b''.join(str(i % 10).encode() for i in range(100)) msg = generate_single_byte_msgs(100)
msgs = [] msgs = []
async def main(): async def main():
@ -245,15 +240,153 @@ def test_ringbuf_max_bytes():
) as token: ) as token:
async with ( async with (
trio.open_nursery() as n, trio.open_nursery() as n,
RingBuffSender(token, is_ipc=False) as sender, attach_to_ringbuf_sender(token, cleanup=False) as sender,
RingBuffReceiver(token, is_ipc=False) as receiver attach_to_ringbuf_receiver(token, cleanup=False) as receiver
): ):
n.start_soon(sender.send_all, msg) async def _send_and_close():
await sender.send_all(msg)
await sender.aclose()
n.start_soon(_send_and_close)
while len(msgs) < len(msg): while len(msgs) < len(msg):
msg_part = await receiver.receive_some(max_bytes=1) msg_part = await receiver.receive_some(max_bytes=1)
msg_part = bytes(msg_part)
assert len(msg_part) == 1 assert len(msg_part) == 1
msgs.append(msg_part) msgs.append(msg_part)
trio.run(main) trio.run(main)
assert msg == b''.join(msgs) assert msg == b''.join(msgs)
def test_stapled_ringbuf():
'''
Open two ringbufs and give tokens to tasks (swap them such that in/out tokens
are inversed on each task) which will open the streams and use trio.StapledStream
to have a single bidirectional stream.
Then take turns to send and receive messages.
'''
msg = generate_single_byte_msgs(100)
pair_0_msgs = []
pair_1_msgs = []
pair_0_done = trio.Event()
pair_1_done = trio.Event()
async def pair_0(token_in: RBToken, token_out: RBToken):
async with attach_to_ringbuf_pair(
token_in,
token_out,
cleanup_in=False,
cleanup_out=False
) as stream:
# first turn to send
await stream.send_all(msg)
# second turn to receive
while len(pair_0_msgs) != len(msg):
_msg = await stream.receive_some(max_bytes=1)
pair_0_msgs.append(_msg)
pair_0_done.set()
await pair_1_done.wait()
async def pair_1(token_in: RBToken, token_out: RBToken):
async with attach_to_ringbuf_pair(
token_in,
token_out,
cleanup_in=False,
cleanup_out=False
) as stream:
# first turn to receive
while len(pair_1_msgs) != len(msg):
_msg = await stream.receive_some(max_bytes=1)
pair_1_msgs.append(_msg)
# second turn to send
await stream.send_all(msg)
pair_1_done.set()
await pair_0_done.wait()
async def main():
with tractor.ipc.open_ringbuf_pair(
'test_stapled_ringbuf'
) as (token_0, token_1):
async with trio.open_nursery() as n:
n.start_soon(pair_0, token_0, token_1)
n.start_soon(pair_1, token_1, token_0)
trio.run(main)
assert msg == b''.join(pair_0_msgs)
assert msg == b''.join(pair_1_msgs)
@tractor.context
async def child_transport_sender(
ctx: tractor.Context,
msg_amount_min: int,
msg_amount_max: int,
token_in: RBToken,
token_out: RBToken
):
import random
msgs, _total_bytes = generate_sample_messages(
random.randint(msg_amount_min, msg_amount_max),
rand_min=256,
rand_max=1024,
)
async with attach_to_ringbuf_stream(
token_in,
token_out
) as transport:
await ctx.started(msgs)
for msg in msgs:
await transport.send(msg)
await transport.recv()
def test_ringbuf_transport():
msg_amount_min = 100
msg_amount_max = 1000
async def main():
with tractor.ipc.open_ringbuf_pair(
'test_ringbuf_transport'
) as (token_0, token_1):
async with (
attach_to_ringbuf_stream(token_0, token_1) as transport,
tractor.open_nursery() as an
):
recv_p = await an.start_actor(
'test_ringbuf_transport_sender',
enable_modules=[__name__],
proc_kwargs={
'pass_fds': token_0.fds + token_1.fds
}
)
async with (
recv_p.open_context(
child_transport_sender,
msg_amount_min=msg_amount_min,
msg_amount_max=msg_amount_max,
token_in=token_1,
token_out=token_0
) as (ctx, msgs),
):
recv_msgs = []
while len(recv_msgs) < len(msgs):
recv_msgs.append(await transport.recv())
await transport.send(b'end')
await recv_p.cancel_actor()
assert recv_msgs == msgs
trio.run(main)

View File

@ -2,6 +2,10 @@ import os
import random import random
def generate_single_byte_msgs(amount: int) -> bytes:
return b''.join(str(i % 10).encode() for i in range(amount))
def generate_sample_messages( def generate_sample_messages(
amount: int, amount: int,
rand_min: int = 0, rand_min: int = 0,

View File

@ -39,12 +39,19 @@ if platform.system() == 'Linux':
write_eventfd as write_eventfd, write_eventfd as write_eventfd,
read_eventfd as read_eventfd, read_eventfd as read_eventfd,
close_eventfd as close_eventfd, close_eventfd as close_eventfd,
EFDReadCancelled as EFDReadCancelled,
EventFD as EventFD, EventFD as EventFD,
) )
from ._ringbuf import ( from ._ringbuf import (
RBToken as RBToken, RBToken as RBToken,
open_ringbuf as open_ringbuf,
RingBuffSender as RingBuffSender, RingBuffSender as RingBuffSender,
RingBuffReceiver as RingBuffReceiver, RingBuffReceiver as RingBuffReceiver,
open_ringbuf as open_ringbuf open_ringbuf_pair as open_ringbuf_pair,
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
attach_to_ringbuf_sender as attach_to_ringbuf_sender,
attach_to_ringbuf_pair as attach_to_ringbuf_pair,
attach_to_ringbuf_stream as attach_to_ringbuf_stream,
MsgpackRBStream as MsgpackRBStream
) )

View File

@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
raise OSError(errno.errorcode[ffi.errno], 'close failed') raise OSError(errno.errorcode[ffi.errno], 'close failed')
class EFDReadCancelled(Exception):
...
class EventFD: class EventFD:
''' '''
Use a previously opened eventfd(2), meant to be used in Use a previously opened eventfd(2), meant to be used in
@ -124,6 +128,7 @@ class EventFD:
self._fd: int = fd self._fd: int = fd
self._omode: str = omode self._omode: str = omode
self._fobj = None self._fobj = None
self._cscope: trio.CancelScope | None = None
@property @property
def fd(self) -> int | None: def fd(self) -> int | None:
@ -133,17 +138,38 @@ class EventFD:
return write_eventfd(self._fd, value) return write_eventfd(self._fd, value)
async def read(self) -> int: async def read(self) -> int:
return await trio.to_thread.run_sync( '''
read_eventfd, self._fd, Async wrapper for `read_eventfd(self.fd)`
abandon_on_cancel=True
) `trio.to_thread.run_sync` is used, need to use a `trio.CancelScope`
in order to make it cancellable when `self.close()` is called.
'''
self._cscope = trio.CancelScope()
with self._cscope:
return await trio.to_thread.run_sync(
read_eventfd, self._fd,
abandon_on_cancel=True
)
if self._cscope.cancelled_caught:
raise EFDReadCancelled
self._cscope = None
def open(self): def open(self):
self._fobj = os.fdopen(self._fd, self._omode) self._fobj = os.fdopen(self._fd, self._omode)
def close(self): def close(self):
if self._fobj: if self._fobj:
self._fobj.close() try:
self._fobj.close()
except OSError:
...
if self._cscope:
self._cscope.cancel()
def __enter__(self): def __enter__(self):
self.open() self.open()

View File

@ -18,10 +18,22 @@ IPC Reliable RingBuffer implementation
''' '''
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager as cm import struct
from collections.abc import (
AsyncGenerator,
AsyncIterator
)
from contextlib import (
contextmanager as cm,
asynccontextmanager as acm
)
from typing import (
Any
)
from multiprocessing.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
import trio import trio
from tricycle import BufferedReceiveStream
from msgspec import ( from msgspec import (
Struct, Struct,
to_builtins to_builtins
@ -30,10 +42,16 @@ from msgspec import (
from ._linux import ( from ._linux import (
open_eventfd, open_eventfd,
close_eventfd, close_eventfd,
EFDReadCancelled,
EventFD EventFD
) )
from ._mp_bs import disable_mantracker from ._mp_bs import disable_mantracker
from tractor.log import get_logger from tractor.log import get_logger
from tractor._exceptions import (
TransportClosed,
InternalError
)
from tractor.ipc import MsgTransport
log = get_logger(__name__) log = get_logger(__name__)
@ -41,16 +59,21 @@ log = get_logger(__name__)
disable_mantracker() disable_mantracker()
_DEFAULT_RB_SIZE = 10 * 1024
class RBToken(Struct, frozen=True): class RBToken(Struct, frozen=True):
''' '''
RingBuffer token contains necesary info to open the two RingBuffer token contains necesary info to open the three
eventfds and the shared memory eventfds and the shared memory
''' '''
shm_name: str shm_name: str
write_eventfd: int
wrap_eventfd: int write_eventfd: int # used to signal writer ptr advance
wrap_eventfd: int # used to signal reader ready after wrap around
eof_eventfd: int # used to signal writer closed
buf_size: int buf_size: int
def as_msg(self): def as_msg(self):
@ -63,12 +86,29 @@ class RBToken(Struct, frozen=True):
return RBToken(**msg) return RBToken(**msg)
@property
def fds(self) -> tuple[int, int, int]:
'''
Useful for `pass_fds` params
'''
return (
self.write_eventfd,
self.wrap_eventfd,
self.eof_eventfd
)
@cm @cm
def open_ringbuf( def open_ringbuf(
shm_name: str, shm_name: str,
buf_size: int = 10 * 1024, buf_size: int = _DEFAULT_RB_SIZE,
) -> RBToken: ) -> RBToken:
'''
Handle resources for a ringbuf (shm, eventfd), yield `RBToken` to
be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver`
'''
shm = SharedMemory( shm = SharedMemory(
name=shm_name, name=shm_name,
size=buf_size, size=buf_size,
@ -79,11 +119,27 @@ def open_ringbuf(
shm_name=shm_name, shm_name=shm_name,
write_eventfd=open_eventfd(), write_eventfd=open_eventfd(),
wrap_eventfd=open_eventfd(), wrap_eventfd=open_eventfd(),
eof_eventfd=open_eventfd(),
buf_size=buf_size buf_size=buf_size
) )
yield token yield token
close_eventfd(token.write_eventfd) try:
close_eventfd(token.wrap_eventfd) close_eventfd(token.write_eventfd)
except OSError:
...
try:
close_eventfd(token.wrap_eventfd)
except OSError:
...
try:
close_eventfd(token.eof_eventfd)
except OSError:
...
finally: finally:
shm.unlink() shm.unlink()
@ -91,28 +147,36 @@ def open_ringbuf(
Buffer = bytes | bytearray | memoryview Buffer = bytes | bytearray | memoryview
'''
IPC Reliable Ring Buffer
`eventfd(2)` is used for wrap around sync, to signal writes to
the reader and end of stream.
'''
class RingBuffSender(trio.abc.SendStream): class RingBuffSender(trio.abc.SendStream):
''' '''
IPC Reliable Ring Buffer sender side implementation Ring Buffer sender side implementation
`eventfd(2)` is used for wrap around sync, and also to signal Do not use directly! manage with `attach_to_ringbuf_sender`
writes to the reader. after having opened a ringbuf context with `open_ringbuf`.
''' '''
def __init__( def __init__(
self, self,
token: RBToken, token: RBToken,
start_ptr: int = 0, cleanup: bool = False
is_ipc: bool = True
): ):
self._token = RBToken.from_msg(token) self._token = RBToken.from_msg(token)
self._shm: SharedMemory | None = None self._shm: SharedMemory | None = None
self._write_event = EventFD(self._token.write_eventfd, 'w') self._write_event = EventFD(self._token.write_eventfd, 'w')
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
self._ptr = start_ptr self._eof_event = EventFD(self._token.eof_eventfd, 'w')
self._ptr = 0
self._is_ipc = is_ipc self._cleanup = cleanup
self._send_lock = trio.StrictFIFOLock() self._send_lock = trio.StrictFIFOLock()
@property @property
@ -170,13 +234,21 @@ class RingBuffSender(trio.abc.SendStream):
) )
self._write_event.open() self._write_event.open()
self._wrap_event.open() self._wrap_event.open()
self._eof_event.open()
async def aclose(self): def close(self):
if self._is_ipc: self._eof_event.write(
self._ptr if self._ptr > 0 else self.size
)
if self._cleanup:
self._write_event.close() self._write_event.close()
self._wrap_event.close() self._wrap_event.close()
self._eof_event.close()
self._shm.close() self._shm.close()
async def aclose(self):
self.close()
async def __aenter__(self): async def __aenter__(self):
self.open() self.open()
return self return self
@ -184,25 +256,27 @@ class RingBuffSender(trio.abc.SendStream):
class RingBuffReceiver(trio.abc.ReceiveStream): class RingBuffReceiver(trio.abc.ReceiveStream):
''' '''
IPC Reliable Ring Buffer receiver side implementation Ring Buffer receiver side implementation
`eventfd(2)` is used for wrap around sync, and also to signal Do not use directly! manage with `attach_to_ringbuf_receiver`
writes to the reader. after having opened a ringbuf context with `open_ringbuf`.
''' '''
def __init__( def __init__(
self, self,
token: RBToken, token: RBToken,
start_ptr: int = 0, cleanup: bool = True,
is_ipc: bool = True
): ):
self._token = RBToken.from_msg(token) self._token = RBToken.from_msg(token)
self._shm: SharedMemory | None = None self._shm: SharedMemory | None = None
self._write_event = EventFD(self._token.write_eventfd, 'w') self._write_event = EventFD(self._token.write_eventfd, 'w')
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
self._ptr = start_ptr self._eof_event = EventFD(self._token.eof_eventfd, 'r')
self._write_ptr = start_ptr self._ptr: int = 0
self._is_ipc = is_ipc self._write_ptr: int = 0
self._end_ptr: int = -1
self._cleanup: bool = cleanup
@property @property
def name(self) -> str: def name(self) -> str:
@ -226,21 +300,71 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
async def receive_some(self, max_bytes: int | None = None) -> memoryview: async def _eof_monitor_task(self):
'''
Long running EOF event monitor, automatically run in bg by
`attach_to_ringbuf_receiver` context manager, if EOF event
is set its value will be the end pointer (highest valid
index to be read from buf, after setting the `self._end_ptr`
we close the write event which should cancel any blocked
`self._write_event.read()`s on it.
'''
try:
self._end_ptr = await self._eof_event.read()
self._write_event.close()
except EFDReadCancelled:
...
async def receive_some(self, max_bytes: int | None = None) -> bytes:
'''
Receive up to `max_bytes`, if no `max_bytes` is provided
a reasonable default is used.
'''
if max_bytes is None:
max_bytes: int = _DEFAULT_RB_SIZE
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
# delta is remaining bytes we havent read
delta = self._write_ptr - self._ptr delta = self._write_ptr - self._ptr
if delta == 0: if delta == 0:
delta = await self._write_event.read() # we have read all we can, see if new data is available
self._write_ptr += delta if self._end_ptr < 0:
# if we havent been signaled about EOF yet
try:
delta = await self._write_event.read()
self._write_ptr += delta
if isinstance(max_bytes, int): except EFDReadCancelled:
if max_bytes == 0: # while waiting for new data `self._write_event` was closed
raise ValueError('if set, max_bytes must be > 0') # this means writer signaled EOF
delta = min(delta, max_bytes) if self._end_ptr > 0:
# final self._write_ptr modification and recalculate delta
self._write_ptr = self._end_ptr
delta = self._end_ptr - self._ptr
else:
# shouldnt happen cause self._eof_monitor_task always sets
# self._end_ptr before closing self._write_event
raise InternalError(
'self._write_event.read cancelled but self._end_ptr is not set'
)
else:
# no more bytes to read and self._end_ptr set, EOF reached
return b''
# dont overflow caller
delta = min(delta, max_bytes)
target_ptr = self._ptr + delta target_ptr = self._ptr + delta
# fetch next segment and advance ptr # fetch next segment and advance ptr
segment = self._shm.buf[self._ptr:target_ptr] segment = bytes(self._shm.buf[self._ptr:target_ptr])
self._ptr = target_ptr self._ptr = target_ptr
if self._ptr == self.size: if self._ptr == self.size:
@ -259,13 +383,284 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
) )
self._write_event.open() self._write_event.open()
self._wrap_event.open() self._wrap_event.open()
self._eof_event.open()
async def aclose(self): def close(self):
if self._is_ipc: if self._cleanup:
self._write_event.close() self._write_event.close()
self._wrap_event.close() self._wrap_event.close()
self._eof_event.close()
self._shm.close() self._shm.close()
async def aclose(self):
self.close()
async def __aenter__(self): async def __aenter__(self):
self.open() self.open()
return self return self
@acm
async def attach_to_ringbuf_receiver(
token: RBToken,
cleanup: bool = True
):
'''
Instantiate a RingBuffReceiver from a previously opened
RBToken.
Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
'''
async with (
trio.open_nursery() as n,
RingBuffReceiver(
token,
cleanup=cleanup
) as receiver
):
n.start_soon(receiver._eof_monitor_task)
yield receiver
@acm
async def attach_to_ringbuf_sender(
token: RBToken,
cleanup: bool = True
):
'''
Instantiate a RingBuffSender from a previously opened
RBToken.
'''
async with RingBuffSender(
token,
cleanup=cleanup
) as sender:
yield sender
@cm
def open_ringbuf_pair(
name: str,
buf_size: int = _DEFAULT_RB_SIZE
):
'''
Handle resources for a ringbuf pair to be used for
bidirectional messaging.
'''
with (
open_ringbuf(
name + '.pair0',
buf_size=buf_size
) as token_0,
open_ringbuf(
name + '.pair1',
buf_size=buf_size
) as token_1
):
yield token_0, token_1
@acm
async def attach_to_ringbuf_pair(
token_in: RBToken,
token_out: RBToken,
cleanup_in: bool = True,
cleanup_out: bool = True
):
'''
Instantiate a trio.StapledStream from a previously opened
ringbuf pair.
'''
async with (
attach_to_ringbuf_receiver(
token_in,
cleanup=cleanup_in
) as receiver,
attach_to_ringbuf_sender(
token_out,
cleanup=cleanup_out
) as sender,
):
yield trio.StapledStream(sender, receiver)
class MsgpackRBStream(MsgTransport):
def __init__(
self,
stream: trio.StapledStream
):
self.stream = stream
# create read loop intance
self._aiter_pkts = self._iter_packets()
self._send_lock = trio.StrictFIFOLock()
self.drained: list[dict] = []
self.recv_stream = BufferedReceiveStream(
transport_stream=stream
)
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''
Yield `bytes`-blob decoded packets from the underlying TCP
stream using the current task's `MsgCodec`.
This is a streaming routine implemented as an async generator
func (which was the original design, but could be changed?)
and is allocated by a `.__call__()` inside `.__init__()` where
it is assigned to the `._aiter_pkts` attr.
'''
while True:
try:
header: bytes = await self.recv_stream.receive_exactly(4)
except (
ValueError,
ConnectionResetError,
# not sure entirely why we need this but without it we
# seem to be getting racy failures here on
# arbiter/registry name subs..
trio.BrokenResourceError,
) as trans_err:
loglevel = 'transport'
match trans_err:
# case (
# ConnectionResetError()
# ):
# loglevel = 'transport'
# peer actor (graceful??) TCP EOF but `tricycle`
# seems to raise a 0-bytes-read?
case ValueError() if (
'unclean EOF' in trans_err.args[0]
):
pass
# peer actor (task) prolly shutdown quickly due
# to cancellation
case trio.BrokenResourceError() if (
'Connection reset by peer' in trans_err.args[0]
):
pass
# unless the disconnect condition falls under "a
# normal operation breakage" we usualy console warn
# about it.
case _:
loglevel: str = 'warning'
raise TransportClosed(
message=(
f'IPC transport already closed by peer\n'
f'x)> {type(trans_err)}\n'
f' |_{self}\n'
),
loglevel=loglevel,
) from trans_err
# XXX definitely can happen if transport is closed
# manually by another `trio.lowlevel.Task` in the
# same actor; we use this in some simulated fault
# testing for ex, but generally should never happen
# under normal operation!
#
# NOTE: as such we always re-raise this error from the
# RPC msg loop!
except trio.ClosedResourceError as closure_err:
raise TransportClosed(
message=(
f'IPC transport already manually closed locally?\n'
f'x)> {type(closure_err)} \n'
f' |_{self}\n'
),
loglevel='error',
raise_on_report=(
closure_err.args[0] == 'another task closed this fd'
or
closure_err.args[0] in ['another task closed this fd']
),
) from closure_err
# graceful EOF disconnect
if header == b'':
raise TransportClosed(
message=(
f'IPC transport already gracefully closed\n'
f')>\n'
f'|_{self}\n'
),
loglevel='transport',
# cause=??? # handy or no?
)
size: int
size, = struct.unpack("<I", header)
log.transport(f'received header {size}') # type: ignore
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
log.transport(f"received {msg_bytes}") # type: ignore
yield msg_bytes
async def send(
self,
msg: bytes,
) -> None:
'''
Send a msgpack encoded py-object-blob-as-msg.
'''
async with self._send_lock:
size: bytes = struct.pack("<I", len(msg))
return await self.stream.send_all(size + msg)
async def recv(self) -> Any:
return await self._aiter_pkts.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._aiter_pkts
@acm
async def attach_to_ringbuf_stream(
token_in: RBToken,
token_out: RBToken,
cleanup_in: bool = True,
cleanup_out: bool = True
):
'''
Wrap a ringbuf trio.StapledStream in a MsgpackRBStream
'''
async with attach_to_ringbuf_pair(
token_in,
token_out,
cleanup_in=cleanup_in,
cleanup_out=cleanup_out
) as stream:
yield MsgpackRBStream(stream)

View File

@ -26,7 +26,6 @@ import struct
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Type,
) )
import msgspec import msgspec

View File

@ -41,10 +41,10 @@ class MsgTransport(Protocol[MsgType]):
# eventual msg definition/types? # eventual msg definition/types?
# - https://docs.python.org/3/library/typing.html#typing.Protocol # - https://docs.python.org/3/library/typing.html#typing.Protocol
stream: trio.SocketStream stream: trio.abc.Stream
drained: list[MsgType] drained: list[MsgType]
def __init__(self, stream: trio.SocketStream) -> None: def __init__(self, stream: trio.abc.Stream) -> None:
... ...
# XXX: should this instead be called `.sendall()`? # XXX: should this instead be called `.sendall()`?