Switch `tractor.ipc.MsgTransport.stream` type to `trio.abc.Stream`
Add EOF signaling mechanism Support proper `receive_some` end of stream semantics Add StapledStream non-ipc test Create MsgpackRBStream similar to MsgpackTCPStream for buffered whole-msg reads Add EventFD.read cancellation on EventFD.close mechanism using cancel scope Add test for eventfd cancellation Improve and add docstrings
parent
d6721f06df
commit
5cec4ee943
|
@ -0,0 +1,32 @@
|
|||
import trio
|
||||
import pytest
|
||||
from tractor.ipc import (
|
||||
open_eventfd,
|
||||
EFDReadCancelled,
|
||||
EventFD
|
||||
)
|
||||
|
||||
|
||||
def test_eventfd_read_cancellation():
|
||||
'''
|
||||
Ensure EventFD.read raises EFDReadCancelled if EventFD.close()
|
||||
is called.
|
||||
|
||||
'''
|
||||
fd = open_eventfd()
|
||||
|
||||
async def _read(event: EventFD):
|
||||
with pytest.raises(EFDReadCancelled):
|
||||
await event.read()
|
||||
|
||||
async def main():
|
||||
async with trio.open_nursery() as n:
|
||||
with (
|
||||
EventFD(fd, 'w') as event,
|
||||
trio.fail_after(3)
|
||||
):
|
||||
n.start_soon(_read, event)
|
||||
await trio.sleep(0.2)
|
||||
event.close()
|
||||
|
||||
trio.run(main)
|
|
@ -5,11 +5,16 @@ import pytest
|
|||
import tractor
|
||||
from tractor.ipc import (
|
||||
open_ringbuf,
|
||||
attach_to_ringbuf_receiver,
|
||||
attach_to_ringbuf_sender,
|
||||
attach_to_ringbuf_pair,
|
||||
attach_to_ringbuf_stream,
|
||||
RBToken,
|
||||
RingBuffSender,
|
||||
RingBuffReceiver
|
||||
)
|
||||
from tractor._testing.samples import generate_sample_messages
|
||||
from tractor._testing.samples import (
|
||||
generate_single_byte_msgs,
|
||||
generate_sample_messages
|
||||
)
|
||||
|
||||
|
||||
@tractor.context
|
||||
|
@ -17,20 +22,14 @@ async def child_read_shm(
|
|||
ctx: tractor.Context,
|
||||
msg_amount: int,
|
||||
token: RBToken,
|
||||
total_bytes: int,
|
||||
) -> None:
|
||||
recvd_bytes = 0
|
||||
await ctx.started()
|
||||
start_ts = time.time()
|
||||
async with RingBuffReceiver(token) as receiver:
|
||||
while recvd_bytes < total_bytes:
|
||||
msg = await receiver.receive_some()
|
||||
async with attach_to_ringbuf_receiver(token) as receiver:
|
||||
async for msg in receiver:
|
||||
recvd_bytes += len(msg)
|
||||
|
||||
# make sure we dont hold any memoryviews
|
||||
# before the ctx manager aclose()
|
||||
msg = None
|
||||
|
||||
end_ts = time.time()
|
||||
elapsed = end_ts - start_ts
|
||||
elapsed_ms = int(elapsed * 1000)
|
||||
|
@ -38,6 +37,7 @@ async def child_read_shm(
|
|||
print(f'\n\telapsed ms: {elapsed_ms}')
|
||||
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
||||
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
||||
print(f'\treceived bytes: {recvd_bytes}')
|
||||
|
||||
|
||||
@tractor.context
|
||||
|
@ -54,7 +54,7 @@ async def child_write_shm(
|
|||
rand_max=rand_max,
|
||||
)
|
||||
await ctx.started(total_bytes)
|
||||
async with RingBuffSender(token) as sender:
|
||||
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
|
||||
for msg in msgs:
|
||||
await sender.send_all(msg)
|
||||
|
||||
|
@ -99,14 +99,8 @@ def test_ringbuf(
|
|||
'test_ringbuf',
|
||||
buf_size=buf_size
|
||||
) as token:
|
||||
proc_kwargs = {
|
||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
||||
}
|
||||
proc_kwargs = {'pass_fds': token.fds}
|
||||
|
||||
common_kwargs = {
|
||||
'msg_amount': msg_amount,
|
||||
'token': token,
|
||||
}
|
||||
async with tractor.open_nursery() as an:
|
||||
send_p = await an.start_actor(
|
||||
'ring_sender',
|
||||
|
@ -121,14 +115,15 @@ def test_ringbuf(
|
|||
async with (
|
||||
send_p.open_context(
|
||||
child_write_shm,
|
||||
token=token,
|
||||
msg_amount=msg_amount,
|
||||
rand_min=rand_min,
|
||||
rand_max=rand_max,
|
||||
**common_kwargs
|
||||
) as (sctx, total_bytes),
|
||||
recv_p.open_context(
|
||||
child_read_shm,
|
||||
**common_kwargs,
|
||||
total_bytes=total_bytes,
|
||||
token=token,
|
||||
msg_amount=msg_amount
|
||||
) as (sctx, _sent),
|
||||
):
|
||||
await recv_p.result()
|
||||
|
@ -145,7 +140,7 @@ async def child_blocked_receiver(
|
|||
ctx: tractor.Context,
|
||||
token: RBToken
|
||||
):
|
||||
async with RingBuffReceiver(token) as receiver:
|
||||
async with attach_to_ringbuf_receiver(token) as receiver:
|
||||
await ctx.started()
|
||||
await receiver.receive_some()
|
||||
|
||||
|
@ -160,13 +155,13 @@ def test_ring_reader_cancel():
|
|||
with open_ringbuf('test_ring_cancel_reader') as token:
|
||||
async with (
|
||||
tractor.open_nursery() as an,
|
||||
RingBuffSender(token) as _sender,
|
||||
attach_to_ringbuf_sender(token) as _sender,
|
||||
):
|
||||
recv_p = await an.start_actor(
|
||||
'ring_blocked_receiver',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs={
|
||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
||||
'pass_fds': token.fds
|
||||
}
|
||||
)
|
||||
async with (
|
||||
|
@ -188,7 +183,7 @@ async def child_blocked_sender(
|
|||
ctx: tractor.Context,
|
||||
token: RBToken
|
||||
):
|
||||
async with RingBuffSender(token) as sender:
|
||||
async with attach_to_ringbuf_sender(token) as sender:
|
||||
await ctx.started()
|
||||
await sender.send_all(b'this will wrap')
|
||||
|
||||
|
@ -209,7 +204,7 @@ def test_ring_sender_cancel():
|
|||
'ring_blocked_sender',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs={
|
||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
||||
'pass_fds': token.fds
|
||||
}
|
||||
)
|
||||
async with (
|
||||
|
@ -235,7 +230,7 @@ def test_ringbuf_max_bytes():
|
|||
msgs with original message
|
||||
|
||||
'''
|
||||
msg = b''.join(str(i % 10).encode() for i in range(100))
|
||||
msg = generate_single_byte_msgs(100)
|
||||
msgs = []
|
||||
|
||||
async def main():
|
||||
|
@ -245,15 +240,153 @@ def test_ringbuf_max_bytes():
|
|||
) as token:
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
RingBuffSender(token, is_ipc=False) as sender,
|
||||
RingBuffReceiver(token, is_ipc=False) as receiver
|
||||
attach_to_ringbuf_sender(token, cleanup=False) as sender,
|
||||
attach_to_ringbuf_receiver(token, cleanup=False) as receiver
|
||||
):
|
||||
n.start_soon(sender.send_all, msg)
|
||||
async def _send_and_close():
|
||||
await sender.send_all(msg)
|
||||
await sender.aclose()
|
||||
|
||||
n.start_soon(_send_and_close)
|
||||
while len(msgs) < len(msg):
|
||||
msg_part = await receiver.receive_some(max_bytes=1)
|
||||
msg_part = bytes(msg_part)
|
||||
assert len(msg_part) == 1
|
||||
msgs.append(msg_part)
|
||||
|
||||
trio.run(main)
|
||||
assert msg == b''.join(msgs)
|
||||
|
||||
|
||||
def test_stapled_ringbuf():
|
||||
'''
|
||||
Open two ringbufs and give tokens to tasks (swap them such that in/out tokens
|
||||
are inversed on each task) which will open the streams and use trio.StapledStream
|
||||
to have a single bidirectional stream.
|
||||
|
||||
Then take turns to send and receive messages.
|
||||
|
||||
'''
|
||||
msg = generate_single_byte_msgs(100)
|
||||
pair_0_msgs = []
|
||||
pair_1_msgs = []
|
||||
|
||||
pair_0_done = trio.Event()
|
||||
pair_1_done = trio.Event()
|
||||
|
||||
async def pair_0(token_in: RBToken, token_out: RBToken):
|
||||
async with attach_to_ringbuf_pair(
|
||||
token_in,
|
||||
token_out,
|
||||
cleanup_in=False,
|
||||
cleanup_out=False
|
||||
) as stream:
|
||||
# first turn to send
|
||||
await stream.send_all(msg)
|
||||
|
||||
# second turn to receive
|
||||
while len(pair_0_msgs) != len(msg):
|
||||
_msg = await stream.receive_some(max_bytes=1)
|
||||
pair_0_msgs.append(_msg)
|
||||
|
||||
pair_0_done.set()
|
||||
await pair_1_done.wait()
|
||||
|
||||
|
||||
async def pair_1(token_in: RBToken, token_out: RBToken):
|
||||
async with attach_to_ringbuf_pair(
|
||||
token_in,
|
||||
token_out,
|
||||
cleanup_in=False,
|
||||
cleanup_out=False
|
||||
) as stream:
|
||||
# first turn to receive
|
||||
while len(pair_1_msgs) != len(msg):
|
||||
_msg = await stream.receive_some(max_bytes=1)
|
||||
pair_1_msgs.append(_msg)
|
||||
|
||||
# second turn to send
|
||||
await stream.send_all(msg)
|
||||
|
||||
pair_1_done.set()
|
||||
await pair_0_done.wait()
|
||||
|
||||
|
||||
async def main():
|
||||
with tractor.ipc.open_ringbuf_pair(
|
||||
'test_stapled_ringbuf'
|
||||
) as (token_0, token_1):
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(pair_0, token_0, token_1)
|
||||
n.start_soon(pair_1, token_1, token_0)
|
||||
|
||||
|
||||
trio.run(main)
|
||||
|
||||
assert msg == b''.join(pair_0_msgs)
|
||||
assert msg == b''.join(pair_1_msgs)
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def child_transport_sender(
|
||||
ctx: tractor.Context,
|
||||
msg_amount_min: int,
|
||||
msg_amount_max: int,
|
||||
token_in: RBToken,
|
||||
token_out: RBToken
|
||||
):
|
||||
import random
|
||||
msgs, _total_bytes = generate_sample_messages(
|
||||
random.randint(msg_amount_min, msg_amount_max),
|
||||
rand_min=256,
|
||||
rand_max=1024,
|
||||
)
|
||||
async with attach_to_ringbuf_stream(
|
||||
token_in,
|
||||
token_out
|
||||
) as transport:
|
||||
await ctx.started(msgs)
|
||||
|
||||
for msg in msgs:
|
||||
await transport.send(msg)
|
||||
|
||||
await transport.recv()
|
||||
|
||||
|
||||
def test_ringbuf_transport():
|
||||
|
||||
msg_amount_min = 100
|
||||
msg_amount_max = 1000
|
||||
|
||||
async def main():
|
||||
with tractor.ipc.open_ringbuf_pair(
|
||||
'test_ringbuf_transport'
|
||||
) as (token_0, token_1):
|
||||
async with (
|
||||
attach_to_ringbuf_stream(token_0, token_1) as transport,
|
||||
tractor.open_nursery() as an
|
||||
):
|
||||
recv_p = await an.start_actor(
|
||||
'test_ringbuf_transport_sender',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs={
|
||||
'pass_fds': token_0.fds + token_1.fds
|
||||
}
|
||||
)
|
||||
async with (
|
||||
recv_p.open_context(
|
||||
child_transport_sender,
|
||||
msg_amount_min=msg_amount_min,
|
||||
msg_amount_max=msg_amount_max,
|
||||
token_in=token_1,
|
||||
token_out=token_0
|
||||
) as (ctx, msgs),
|
||||
):
|
||||
recv_msgs = []
|
||||
while len(recv_msgs) < len(msgs):
|
||||
recv_msgs.append(await transport.recv())
|
||||
|
||||
await transport.send(b'end')
|
||||
await recv_p.cancel_actor()
|
||||
assert recv_msgs == msgs
|
||||
|
||||
trio.run(main)
|
||||
|
|
|
@ -2,6 +2,10 @@ import os
|
|||
import random
|
||||
|
||||
|
||||
def generate_single_byte_msgs(amount: int) -> bytes:
|
||||
return b''.join(str(i % 10).encode() for i in range(amount))
|
||||
|
||||
|
||||
def generate_sample_messages(
|
||||
amount: int,
|
||||
rand_min: int = 0,
|
||||
|
|
|
@ -39,12 +39,19 @@ if platform.system() == 'Linux':
|
|||
write_eventfd as write_eventfd,
|
||||
read_eventfd as read_eventfd,
|
||||
close_eventfd as close_eventfd,
|
||||
EFDReadCancelled as EFDReadCancelled,
|
||||
EventFD as EventFD,
|
||||
)
|
||||
|
||||
from ._ringbuf import (
|
||||
RBToken as RBToken,
|
||||
open_ringbuf as open_ringbuf,
|
||||
RingBuffSender as RingBuffSender,
|
||||
RingBuffReceiver as RingBuffReceiver,
|
||||
open_ringbuf as open_ringbuf
|
||||
open_ringbuf_pair as open_ringbuf_pair,
|
||||
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
|
||||
attach_to_ringbuf_sender as attach_to_ringbuf_sender,
|
||||
attach_to_ringbuf_pair as attach_to_ringbuf_pair,
|
||||
attach_to_ringbuf_stream as attach_to_ringbuf_stream,
|
||||
MsgpackRBStream as MsgpackRBStream
|
||||
)
|
||||
|
|
|
@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
|
|||
raise OSError(errno.errorcode[ffi.errno], 'close failed')
|
||||
|
||||
|
||||
class EFDReadCancelled(Exception):
|
||||
...
|
||||
|
||||
|
||||
class EventFD:
|
||||
'''
|
||||
Use a previously opened eventfd(2), meant to be used in
|
||||
|
@ -124,6 +128,7 @@ class EventFD:
|
|||
self._fd: int = fd
|
||||
self._omode: str = omode
|
||||
self._fobj = None
|
||||
self._cscope: trio.CancelScope | None = None
|
||||
|
||||
@property
|
||||
def fd(self) -> int | None:
|
||||
|
@ -133,17 +138,38 @@ class EventFD:
|
|||
return write_eventfd(self._fd, value)
|
||||
|
||||
async def read(self) -> int:
|
||||
return await trio.to_thread.run_sync(
|
||||
read_eventfd, self._fd,
|
||||
abandon_on_cancel=True
|
||||
)
|
||||
'''
|
||||
Async wrapper for `read_eventfd(self.fd)`
|
||||
|
||||
`trio.to_thread.run_sync` is used, need to use a `trio.CancelScope`
|
||||
in order to make it cancellable when `self.close()` is called.
|
||||
|
||||
'''
|
||||
self._cscope = trio.CancelScope()
|
||||
with self._cscope:
|
||||
return await trio.to_thread.run_sync(
|
||||
read_eventfd, self._fd,
|
||||
abandon_on_cancel=True
|
||||
)
|
||||
|
||||
if self._cscope.cancelled_caught:
|
||||
raise EFDReadCancelled
|
||||
|
||||
self._cscope = None
|
||||
|
||||
def open(self):
|
||||
self._fobj = os.fdopen(self._fd, self._omode)
|
||||
|
||||
def close(self):
|
||||
if self._fobj:
|
||||
self._fobj.close()
|
||||
try:
|
||||
self._fobj.close()
|
||||
|
||||
except OSError:
|
||||
...
|
||||
|
||||
if self._cscope:
|
||||
self._cscope.cancel()
|
||||
|
||||
def __enter__(self):
|
||||
self.open()
|
||||
|
|
|
@ -18,10 +18,22 @@ IPC Reliable RingBuffer implementation
|
|||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
from contextlib import contextmanager as cm
|
||||
import struct
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
AsyncIterator
|
||||
)
|
||||
from contextlib import (
|
||||
contextmanager as cm,
|
||||
asynccontextmanager as acm
|
||||
)
|
||||
from typing import (
|
||||
Any
|
||||
)
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
|
||||
import trio
|
||||
from tricycle import BufferedReceiveStream
|
||||
from msgspec import (
|
||||
Struct,
|
||||
to_builtins
|
||||
|
@ -30,10 +42,16 @@ from msgspec import (
|
|||
from ._linux import (
|
||||
open_eventfd,
|
||||
close_eventfd,
|
||||
EFDReadCancelled,
|
||||
EventFD
|
||||
)
|
||||
from ._mp_bs import disable_mantracker
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
TransportClosed,
|
||||
InternalError
|
||||
)
|
||||
from tractor.ipc import MsgTransport
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
@ -41,16 +59,21 @@ log = get_logger(__name__)
|
|||
|
||||
disable_mantracker()
|
||||
|
||||
_DEFAULT_RB_SIZE = 10 * 1024
|
||||
|
||||
|
||||
class RBToken(Struct, frozen=True):
|
||||
'''
|
||||
RingBuffer token contains necesary info to open the two
|
||||
RingBuffer token contains necesary info to open the three
|
||||
eventfds and the shared memory
|
||||
|
||||
'''
|
||||
shm_name: str
|
||||
write_eventfd: int
|
||||
wrap_eventfd: int
|
||||
|
||||
write_eventfd: int # used to signal writer ptr advance
|
||||
wrap_eventfd: int # used to signal reader ready after wrap around
|
||||
eof_eventfd: int # used to signal writer closed
|
||||
|
||||
buf_size: int
|
||||
|
||||
def as_msg(self):
|
||||
|
@ -63,12 +86,29 @@ class RBToken(Struct, frozen=True):
|
|||
|
||||
return RBToken(**msg)
|
||||
|
||||
@property
|
||||
def fds(self) -> tuple[int, int, int]:
|
||||
'''
|
||||
Useful for `pass_fds` params
|
||||
|
||||
'''
|
||||
return (
|
||||
self.write_eventfd,
|
||||
self.wrap_eventfd,
|
||||
self.eof_eventfd
|
||||
)
|
||||
|
||||
|
||||
@cm
|
||||
def open_ringbuf(
|
||||
shm_name: str,
|
||||
buf_size: int = 10 * 1024,
|
||||
buf_size: int = _DEFAULT_RB_SIZE,
|
||||
) -> RBToken:
|
||||
'''
|
||||
Handle resources for a ringbuf (shm, eventfd), yield `RBToken` to
|
||||
be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver`
|
||||
|
||||
'''
|
||||
shm = SharedMemory(
|
||||
name=shm_name,
|
||||
size=buf_size,
|
||||
|
@ -79,11 +119,27 @@ def open_ringbuf(
|
|||
shm_name=shm_name,
|
||||
write_eventfd=open_eventfd(),
|
||||
wrap_eventfd=open_eventfd(),
|
||||
eof_eventfd=open_eventfd(),
|
||||
buf_size=buf_size
|
||||
)
|
||||
yield token
|
||||
close_eventfd(token.write_eventfd)
|
||||
close_eventfd(token.wrap_eventfd)
|
||||
try:
|
||||
close_eventfd(token.write_eventfd)
|
||||
|
||||
except OSError:
|
||||
...
|
||||
|
||||
try:
|
||||
close_eventfd(token.wrap_eventfd)
|
||||
|
||||
except OSError:
|
||||
...
|
||||
|
||||
try:
|
||||
close_eventfd(token.eof_eventfd)
|
||||
|
||||
except OSError:
|
||||
...
|
||||
|
||||
finally:
|
||||
shm.unlink()
|
||||
|
@ -91,28 +147,36 @@ def open_ringbuf(
|
|||
|
||||
Buffer = bytes | bytearray | memoryview
|
||||
|
||||
'''
|
||||
IPC Reliable Ring Buffer
|
||||
|
||||
`eventfd(2)` is used for wrap around sync, to signal writes to
|
||||
the reader and end of stream.
|
||||
|
||||
'''
|
||||
|
||||
|
||||
class RingBuffSender(trio.abc.SendStream):
|
||||
'''
|
||||
IPC Reliable Ring Buffer sender side implementation
|
||||
Ring Buffer sender side implementation
|
||||
|
||||
`eventfd(2)` is used for wrap around sync, and also to signal
|
||||
writes to the reader.
|
||||
Do not use directly! manage with `attach_to_ringbuf_sender`
|
||||
after having opened a ringbuf context with `open_ringbuf`.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
token: RBToken,
|
||||
start_ptr: int = 0,
|
||||
is_ipc: bool = True
|
||||
cleanup: bool = False
|
||||
):
|
||||
self._token = RBToken.from_msg(token)
|
||||
self._shm: SharedMemory | None = None
|
||||
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
||||
self._ptr = start_ptr
|
||||
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
|
||||
self._ptr = 0
|
||||
|
||||
self._is_ipc = is_ipc
|
||||
self._cleanup = cleanup
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
@property
|
||||
|
@ -170,13 +234,21 @@ class RingBuffSender(trio.abc.SendStream):
|
|||
)
|
||||
self._write_event.open()
|
||||
self._wrap_event.open()
|
||||
self._eof_event.open()
|
||||
|
||||
async def aclose(self):
|
||||
if self._is_ipc:
|
||||
def close(self):
|
||||
self._eof_event.write(
|
||||
self._ptr if self._ptr > 0 else self.size
|
||||
)
|
||||
if self._cleanup:
|
||||
self._write_event.close()
|
||||
self._wrap_event.close()
|
||||
self._eof_event.close()
|
||||
self._shm.close()
|
||||
|
||||
async def aclose(self):
|
||||
self.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.open()
|
||||
return self
|
||||
|
@ -184,25 +256,27 @@ class RingBuffSender(trio.abc.SendStream):
|
|||
|
||||
class RingBuffReceiver(trio.abc.ReceiveStream):
|
||||
'''
|
||||
IPC Reliable Ring Buffer receiver side implementation
|
||||
Ring Buffer receiver side implementation
|
||||
|
||||
`eventfd(2)` is used for wrap around sync, and also to signal
|
||||
writes to the reader.
|
||||
Do not use directly! manage with `attach_to_ringbuf_receiver`
|
||||
after having opened a ringbuf context with `open_ringbuf`.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
token: RBToken,
|
||||
start_ptr: int = 0,
|
||||
is_ipc: bool = True
|
||||
cleanup: bool = True,
|
||||
):
|
||||
self._token = RBToken.from_msg(token)
|
||||
self._shm: SharedMemory | None = None
|
||||
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
||||
self._ptr = start_ptr
|
||||
self._write_ptr = start_ptr
|
||||
self._is_ipc = is_ipc
|
||||
self._eof_event = EventFD(self._token.eof_eventfd, 'r')
|
||||
self._ptr: int = 0
|
||||
self._write_ptr: int = 0
|
||||
self._end_ptr: int = -1
|
||||
|
||||
self._cleanup: bool = cleanup
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
@ -226,21 +300,71 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
|
|||
def wrap_fd(self) -> int:
|
||||
return self._wrap_event.fd
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> memoryview:
|
||||
async def _eof_monitor_task(self):
|
||||
'''
|
||||
Long running EOF event monitor, automatically run in bg by
|
||||
`attach_to_ringbuf_receiver` context manager, if EOF event
|
||||
is set its value will be the end pointer (highest valid
|
||||
index to be read from buf, after setting the `self._end_ptr`
|
||||
we close the write event which should cancel any blocked
|
||||
`self._write_event.read()`s on it.
|
||||
|
||||
'''
|
||||
try:
|
||||
self._end_ptr = await self._eof_event.read()
|
||||
self._write_event.close()
|
||||
|
||||
except EFDReadCancelled:
|
||||
...
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
||||
'''
|
||||
Receive up to `max_bytes`, if no `max_bytes` is provided
|
||||
a reasonable default is used.
|
||||
|
||||
'''
|
||||
if max_bytes is None:
|
||||
max_bytes: int = _DEFAULT_RB_SIZE
|
||||
|
||||
if max_bytes < 1:
|
||||
raise ValueError("max_bytes must be >= 1")
|
||||
|
||||
# delta is remaining bytes we havent read
|
||||
delta = self._write_ptr - self._ptr
|
||||
if delta == 0:
|
||||
delta = await self._write_event.read()
|
||||
self._write_ptr += delta
|
||||
# we have read all we can, see if new data is available
|
||||
if self._end_ptr < 0:
|
||||
# if we havent been signaled about EOF yet
|
||||
try:
|
||||
delta = await self._write_event.read()
|
||||
self._write_ptr += delta
|
||||
|
||||
if isinstance(max_bytes, int):
|
||||
if max_bytes == 0:
|
||||
raise ValueError('if set, max_bytes must be > 0')
|
||||
delta = min(delta, max_bytes)
|
||||
except EFDReadCancelled:
|
||||
# while waiting for new data `self._write_event` was closed
|
||||
# this means writer signaled EOF
|
||||
if self._end_ptr > 0:
|
||||
# final self._write_ptr modification and recalculate delta
|
||||
self._write_ptr = self._end_ptr
|
||||
delta = self._end_ptr - self._ptr
|
||||
|
||||
else:
|
||||
# shouldnt happen cause self._eof_monitor_task always sets
|
||||
# self._end_ptr before closing self._write_event
|
||||
raise InternalError(
|
||||
'self._write_event.read cancelled but self._end_ptr is not set'
|
||||
)
|
||||
|
||||
else:
|
||||
# no more bytes to read and self._end_ptr set, EOF reached
|
||||
return b''
|
||||
|
||||
# dont overflow caller
|
||||
delta = min(delta, max_bytes)
|
||||
|
||||
target_ptr = self._ptr + delta
|
||||
|
||||
# fetch next segment and advance ptr
|
||||
segment = self._shm.buf[self._ptr:target_ptr]
|
||||
segment = bytes(self._shm.buf[self._ptr:target_ptr])
|
||||
self._ptr = target_ptr
|
||||
|
||||
if self._ptr == self.size:
|
||||
|
@ -259,13 +383,284 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
|
|||
)
|
||||
self._write_event.open()
|
||||
self._wrap_event.open()
|
||||
self._eof_event.open()
|
||||
|
||||
async def aclose(self):
|
||||
if self._is_ipc:
|
||||
def close(self):
|
||||
if self._cleanup:
|
||||
self._write_event.close()
|
||||
self._wrap_event.close()
|
||||
self._eof_event.close()
|
||||
self._shm.close()
|
||||
|
||||
async def aclose(self):
|
||||
self.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.open()
|
||||
return self
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_receiver(
|
||||
token: RBToken,
|
||||
cleanup: bool = True
|
||||
):
|
||||
'''
|
||||
Instantiate a RingBuffReceiver from a previously opened
|
||||
RBToken.
|
||||
|
||||
Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
|
||||
'''
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
RingBuffReceiver(
|
||||
token,
|
||||
cleanup=cleanup
|
||||
) as receiver
|
||||
):
|
||||
n.start_soon(receiver._eof_monitor_task)
|
||||
yield receiver
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_sender(
|
||||
token: RBToken,
|
||||
cleanup: bool = True
|
||||
):
|
||||
'''
|
||||
Instantiate a RingBuffSender from a previously opened
|
||||
RBToken.
|
||||
|
||||
'''
|
||||
async with RingBuffSender(
|
||||
token,
|
||||
cleanup=cleanup
|
||||
) as sender:
|
||||
yield sender
|
||||
|
||||
|
||||
@cm
|
||||
def open_ringbuf_pair(
|
||||
name: str,
|
||||
buf_size: int = _DEFAULT_RB_SIZE
|
||||
):
|
||||
'''
|
||||
Handle resources for a ringbuf pair to be used for
|
||||
bidirectional messaging.
|
||||
|
||||
'''
|
||||
with (
|
||||
open_ringbuf(
|
||||
name + '.pair0',
|
||||
buf_size=buf_size
|
||||
) as token_0,
|
||||
|
||||
open_ringbuf(
|
||||
name + '.pair1',
|
||||
buf_size=buf_size
|
||||
) as token_1
|
||||
):
|
||||
yield token_0, token_1
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_pair(
|
||||
token_in: RBToken,
|
||||
token_out: RBToken,
|
||||
cleanup_in: bool = True,
|
||||
cleanup_out: bool = True
|
||||
):
|
||||
'''
|
||||
Instantiate a trio.StapledStream from a previously opened
|
||||
ringbuf pair.
|
||||
|
||||
'''
|
||||
async with (
|
||||
attach_to_ringbuf_receiver(
|
||||
token_in,
|
||||
cleanup=cleanup_in
|
||||
) as receiver,
|
||||
attach_to_ringbuf_sender(
|
||||
token_out,
|
||||
cleanup=cleanup_out
|
||||
) as sender,
|
||||
):
|
||||
yield trio.StapledStream(sender, receiver)
|
||||
|
||||
|
||||
class MsgpackRBStream(MsgTransport):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: trio.StapledStream
|
||||
):
|
||||
self.stream = stream
|
||||
|
||||
# create read loop intance
|
||||
self._aiter_pkts = self._iter_packets()
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
self.drained: list[dict] = []
|
||||
|
||||
self.recv_stream = BufferedReceiveStream(
|
||||
transport_stream=stream
|
||||
)
|
||||
|
||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||
'''
|
||||
Yield `bytes`-blob decoded packets from the underlying TCP
|
||||
stream using the current task's `MsgCodec`.
|
||||
|
||||
This is a streaming routine implemented as an async generator
|
||||
func (which was the original design, but could be changed?)
|
||||
and is allocated by a `.__call__()` inside `.__init__()` where
|
||||
it is assigned to the `._aiter_pkts` attr.
|
||||
|
||||
'''
|
||||
|
||||
while True:
|
||||
try:
|
||||
header: bytes = await self.recv_stream.receive_exactly(4)
|
||||
except (
|
||||
ValueError,
|
||||
ConnectionResetError,
|
||||
|
||||
# not sure entirely why we need this but without it we
|
||||
# seem to be getting racy failures here on
|
||||
# arbiter/registry name subs..
|
||||
trio.BrokenResourceError,
|
||||
|
||||
) as trans_err:
|
||||
|
||||
loglevel = 'transport'
|
||||
match trans_err:
|
||||
# case (
|
||||
# ConnectionResetError()
|
||||
# ):
|
||||
# loglevel = 'transport'
|
||||
|
||||
# peer actor (graceful??) TCP EOF but `tricycle`
|
||||
# seems to raise a 0-bytes-read?
|
||||
case ValueError() if (
|
||||
'unclean EOF' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# peer actor (task) prolly shutdown quickly due
|
||||
# to cancellation
|
||||
case trio.BrokenResourceError() if (
|
||||
'Connection reset by peer' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# unless the disconnect condition falls under "a
|
||||
# normal operation breakage" we usualy console warn
|
||||
# about it.
|
||||
case _:
|
||||
loglevel: str = 'warning'
|
||||
|
||||
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already closed by peer\n'
|
||||
f'x)> {type(trans_err)}\n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel=loglevel,
|
||||
) from trans_err
|
||||
|
||||
# XXX definitely can happen if transport is closed
|
||||
# manually by another `trio.lowlevel.Task` in the
|
||||
# same actor; we use this in some simulated fault
|
||||
# testing for ex, but generally should never happen
|
||||
# under normal operation!
|
||||
#
|
||||
# NOTE: as such we always re-raise this error from the
|
||||
# RPC msg loop!
|
||||
except trio.ClosedResourceError as closure_err:
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already manually closed locally?\n'
|
||||
f'x)> {type(closure_err)} \n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel='error',
|
||||
raise_on_report=(
|
||||
closure_err.args[0] == 'another task closed this fd'
|
||||
or
|
||||
closure_err.args[0] in ['another task closed this fd']
|
||||
),
|
||||
) from closure_err
|
||||
|
||||
# graceful EOF disconnect
|
||||
if header == b'':
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already gracefully closed\n'
|
||||
f')>\n'
|
||||
f'|_{self}\n'
|
||||
),
|
||||
loglevel='transport',
|
||||
# cause=??? # handy or no?
|
||||
)
|
||||
|
||||
size: int
|
||||
size, = struct.unpack("<I", header)
|
||||
|
||||
log.transport(f'received header {size}') # type: ignore
|
||||
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
|
||||
|
||||
log.transport(f"received {msg_bytes}") # type: ignore
|
||||
yield msg_bytes
|
||||
|
||||
async def send(
|
||||
self,
|
||||
msg: bytes,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
Send a msgpack encoded py-object-blob-as-msg.
|
||||
|
||||
'''
|
||||
async with self._send_lock:
|
||||
size: bytes = struct.pack("<I", len(msg))
|
||||
return await self.stream.send_all(size + msg)
|
||||
|
||||
async def recv(self) -> Any:
|
||||
return await self._aiter_pkts.asend(None)
|
||||
|
||||
async def drain(self) -> AsyncIterator[dict]:
|
||||
'''
|
||||
Drain the stream's remaining messages sent from
|
||||
the far end until the connection is closed by
|
||||
the peer.
|
||||
|
||||
'''
|
||||
try:
|
||||
async for msg in self._iter_packets():
|
||||
self.drained.append(msg)
|
||||
except TransportClosed:
|
||||
for msg in self.drained:
|
||||
yield msg
|
||||
|
||||
def __aiter__(self):
|
||||
return self._aiter_pkts
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_stream(
|
||||
token_in: RBToken,
|
||||
token_out: RBToken,
|
||||
cleanup_in: bool = True,
|
||||
cleanup_out: bool = True
|
||||
):
|
||||
'''
|
||||
Wrap a ringbuf trio.StapledStream in a MsgpackRBStream
|
||||
|
||||
'''
|
||||
async with attach_to_ringbuf_pair(
|
||||
token_in,
|
||||
token_out,
|
||||
cleanup_in=cleanup_in,
|
||||
cleanup_out=cleanup_out
|
||||
) as stream:
|
||||
yield MsgpackRBStream(stream)
|
||||
|
|
|
@ -26,7 +26,6 @@ import struct
|
|||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Type,
|
||||
)
|
||||
|
||||
import msgspec
|
||||
|
|
|
@ -41,10 +41,10 @@ class MsgTransport(Protocol[MsgType]):
|
|||
# eventual msg definition/types?
|
||||
# - https://docs.python.org/3/library/typing.html#typing.Protocol
|
||||
|
||||
stream: trio.SocketStream
|
||||
stream: trio.abc.Stream
|
||||
drained: list[MsgType]
|
||||
|
||||
def __init__(self, stream: trio.SocketStream) -> None:
|
||||
def __init__(self, stream: trio.abc.Stream) -> None:
|
||||
...
|
||||
|
||||
# XXX: should this instead be called `.sendall()`?
|
||||
|
|
Loading…
Reference in New Issue