Compare commits

...

3 Commits

Author SHA1 Message Date
Guillermo Rodriguez 41e84cc701
Move tractor._shm to tractor.ipc._shm 2025-03-13 21:02:16 -03:00
Guillermo Rodriguez 7b1f42942e
move tractor._ipc.py into tractor.ipc._chan.py 2025-03-13 21:02:16 -03:00
Guillermo Rodriguez c5ae3a767e
General improvements
EventFD class now expects the fd to already be init with open_eventfd
RingBuff Sender and Receiver fully manage SharedMemory and EventFD lifecycles, no aditional ctx mngrs needed
Separate ring buf tests into its own test bed
Add parametrization to test and cancellation
Add docstrings
Add simple testing data gen module .samples
2025-03-13 21:02:14 -03:00
17 changed files with 386 additions and 177 deletions

View File

@ -0,0 +1,212 @@
import time
import trio
import pytest
import tractor
from tractor.ipc._shm import (
EFD_NONBLOCK,
open_eventfd,
RingBuffSender,
RingBuffReceiver
)
from tractor._testing.samples import generate_sample_messages
@tractor.context
async def child_read_shm(
ctx: tractor.Context,
msg_amount: int,
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
buf_size: int,
total_bytes: int,
flags: int = 0,
) -> None:
recvd_bytes = 0
await ctx.started()
start_ts = time.time()
async with RingBuffReceiver(
shm_key,
write_eventfd,
wrap_eventfd,
buf_size=buf_size,
flags=flags
) as receiver:
while recvd_bytes < total_bytes:
msg = await receiver.receive_some()
recvd_bytes += len(msg)
# make sure we dont hold any memoryviews
# before the ctx manager aclose()
msg = None
end_ts = time.time()
elapsed = end_ts - start_ts
elapsed_ms = int(elapsed * 1000)
print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
@tractor.context
async def child_write_shm(
ctx: tractor.Context,
msg_amount: int,
rand_min: int,
rand_max: int,
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
buf_size: int,
) -> None:
msgs, total_bytes = generate_sample_messages(
msg_amount,
rand_min=rand_min,
rand_max=rand_max,
)
await ctx.started(total_bytes)
async with RingBuffSender(
shm_key,
write_eventfd,
wrap_eventfd,
buf_size=buf_size
) as sender:
for msg in msgs:
await sender.send_all(msg)
@pytest.mark.parametrize(
'msg_amount,rand_min,rand_max,buf_size',
[
# simple case, fixed payloads, large buffer
(100_000, 0, 0, 10 * 1024),
# guaranteed wrap around on every write
(100, 10 * 1024, 20 * 1024, 10 * 1024),
# large payload size, but large buffer
(10_000, 256 * 1024, 512 * 1024, 10 * 1024 * 1024)
],
ids=[
'fixed_payloads_large_buffer',
'wrap_around_every_write',
'large_payloads_large_buffer',
]
)
def test_ring_buff(
msg_amount: int,
rand_min: int,
rand_max: int,
buf_size: int
):
write_eventfd = open_eventfd()
wrap_eventfd = open_eventfd()
proc_kwargs = {
'pass_fds': (write_eventfd, wrap_eventfd)
}
shm_key = 'test_ring_buff'
common_kwargs = {
'msg_amount': msg_amount,
'shm_key': shm_key,
'write_eventfd': write_eventfd,
'wrap_eventfd': wrap_eventfd,
'buf_size': buf_size
}
async def main():
async with tractor.open_nursery() as an:
send_p = await an.start_actor(
'ring_sender',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
recv_p = await an.start_actor(
'ring_receiver',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
async with (
send_p.open_context(
child_write_shm,
rand_min=rand_min,
rand_max=rand_max,
**common_kwargs
) as (sctx, total_bytes),
recv_p.open_context(
child_read_shm,
**common_kwargs,
total_bytes=total_bytes,
) as (sctx, _sent),
):
await recv_p.result()
await send_p.cancel_actor()
await recv_p.cancel_actor()
trio.run(main)
@tractor.context
async def child_blocked_receiver(
ctx: tractor.Context,
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
flags: int = 0
):
async with RingBuffReceiver(
shm_key,
write_eventfd,
wrap_eventfd,
flags=flags
) as receiver:
await ctx.started()
await receiver.receive_some()
def test_ring_reader_cancel():
flags = EFD_NONBLOCK
write_eventfd = open_eventfd(flags=flags)
wrap_eventfd = open_eventfd()
proc_kwargs = {
'pass_fds': (write_eventfd, wrap_eventfd)
}
shm_key = 'test_ring_cancel'
async def main():
async with (
tractor.open_nursery() as an,
RingBuffSender(
shm_key,
write_eventfd,
wrap_eventfd,
) as _sender,
):
recv_p = await an.start_actor(
'ring_blocked_receiver',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
async with (
recv_p.open_context(
child_blocked_receiver,
write_eventfd=write_eventfd,
wrap_eventfd=wrap_eventfd,
shm_key=shm_key,
flags=flags
) as (sctx, _sent),
):
await trio.sleep(1)
await an.cancel()
with pytest.raises(tractor._exceptions.ContextCancelled):
trio.run(main)

View File

@ -2,19 +2,15 @@
Shared mem primitives and APIs.
"""
import time
import uuid
import string
import random
# import numpy
import pytest
import trio
import tractor
from tractor._shm import (
from tractor.ipc._shm import (
open_shm_list,
attach_shm_list,
EventFD, open_ringbuffer_sender, open_ringbuffer_receiver,
)
@ -169,79 +165,3 @@ def test_parent_writer_child_reader(
await portal.cancel_actor()
trio.run(main)
def random_string(size=256):
return ''.join(random.choice(string.ascii_lowercase) for i in range(size))
async def child_read_shm(
msg_amount: int,
key: str,
write_event_fd: int,
wrap_event_fd: int,
max_bytes: int,
) -> None:
log = tractor.log.get_console_log(level='info')
recvd_msgs = 0
start_ts = time.time()
async with open_ringbuffer_receiver(
write_event_fd,
wrap_event_fd,
key,
max_bytes=max_bytes
) as receiver:
while recvd_msgs < msg_amount:
msg = await receiver.receive_some()
msgs = bytes(msg).split(b'\n')
first = msgs[0]
last = msgs[-2]
log.info((receiver.ptr - len(msg), receiver.ptr, first[:10], last[:10]))
recvd_msgs += len(msgs)
end_ts = time.time()
elapsed = end_ts - start_ts
elapsed_ms = int(elapsed * 1000)
log.info(f'elapsed ms: {elapsed_ms}')
log.info(f'msg/sec: {int(msg_amount / elapsed):,}')
log.info(f'bytes/sec: {int(max_bytes / elapsed):,}')
def test_ring_buff():
log = tractor.log.get_console_log(level='info')
msg_amount = 100_000
log.info(f'generating {msg_amount} messages...')
msgs = [
f'[{i:08}]: {random_string()}\n'.encode('utf-8')
for i in range(msg_amount)
]
buf_size = sum((len(m) for m in msgs))
log.info(f'done! buffer size: {buf_size}')
async def main():
with (
EventFD(initval=0) as write_event,
EventFD(initval=0) as wrap_event,
):
async with (
tractor.open_nursery() as an,
open_ringbuffer_sender(
write_event.fd,
wrap_event.fd,
max_bytes=buf_size
) as sender
):
await an.run_in_actor(
child_read_shm,
msg_amount=msg_amount,
key=sender.key,
write_event_fd=write_event.fd,
wrap_event_fd=wrap_event.fd,
max_bytes=buf_size,
proc_kwargs={
'pass_fds': (write_event.fd, wrap_event.fd)
}
)
for msg in msgs:
await sender.send_all(msg)
trio.run(main)

View File

@ -62,6 +62,6 @@ from ._root import (
run_daemon as run_daemon,
open_root_actor as open_root_actor,
)
from ._ipc import Channel as Channel
from .ipc import Channel as Channel
from ._portal import Portal as Portal
from ._runtime import Actor as Actor

View File

@ -85,7 +85,7 @@ from .msg import (
pretty_struct,
_ops as msgops,
)
from ._ipc import (
from .ipc import (
Channel,
)
from ._streaming import (
@ -101,7 +101,7 @@ from ._state import (
if TYPE_CHECKING:
from ._portal import Portal
from ._runtime import Actor
from ._ipc import MsgTransport
from .ipc import MsgTransport
from .devx._frame_stack import (
CallerInfo,
)

View File

@ -29,7 +29,7 @@ from contextlib import asynccontextmanager as acm
from tractor.log import get_logger
from .trionics import gather_contexts
from ._ipc import _connect_chan, Channel
from .ipc import _connect_chan, Channel
from ._portal import (
Portal,
open_portal,

View File

@ -64,7 +64,7 @@ if TYPE_CHECKING:
from ._context import Context
from .log import StackLevelAdapter
from ._stream import MsgStream
from ._ipc import Channel
from .ipc import Channel
log = get_logger('tractor')

View File

@ -43,7 +43,7 @@ from .trionics import maybe_open_nursery
from ._state import (
current_actor,
)
from ._ipc import Channel
from .ipc import Channel
from .log import get_logger
from .msg import (
# Error,

View File

@ -43,7 +43,7 @@ from .devx import _debug
from . import _spawn
from . import _state
from . import log
from ._ipc import _connect_chan
from .ipc import _connect_chan
from ._exceptions import is_multi_cancelled

View File

@ -42,7 +42,7 @@ from trio import (
TaskStatus,
)
from ._ipc import Channel
from .ipc import Channel
from ._context import (
Context,
)

View File

@ -73,7 +73,7 @@ from tractor.msg import (
pretty_struct,
types as msgtypes,
)
from ._ipc import Channel
from .ipc import Channel
from ._context import (
mk_context,
Context,

View File

@ -54,7 +54,7 @@ from tractor.msg import (
if TYPE_CHECKING:
from ._runtime import Actor
from ._context import Context
from ._ipc import Channel
from .ipc import Channel
log = get_logger(__name__)

View File

@ -0,0 +1,35 @@
import os
import random
def generate_sample_messages(
amount: int,
rand_min: int = 0,
rand_max: int = 0,
silent: bool = False
) -> tuple[list[bytes], int]:
msgs = []
size = 0
if not silent:
print(f'\ngenerating {amount} messages...')
for i in range(amount):
msg = f'[{i:08}]'.encode('utf-8')
if rand_max > 0:
msg += os.urandom(
random.randint(rand_min, rand_max))
size += len(msg)
msgs.append(msg)
if not silent and i and i % 10_000 == 0:
print(f'{i} generated')
if not silent:
print(f'done, {size:,} bytes in total')
return msgs, size

View File

@ -91,7 +91,7 @@ from tractor._state import (
if TYPE_CHECKING:
from trio.lowlevel import Task
from threading import Thread
from tractor._ipc import Channel
from tractor.ipc import Channel
from tractor._runtime import (
Actor,
)

View File

@ -0,0 +1,11 @@
from ._chan import (
_connect_chan,
MsgTransport,
Channel
)
__all__ = [
'_connect_chan',
'MsgTransport',
'Channel',
]

View File

@ -36,7 +36,7 @@ from multiprocessing.shared_memory import (
from msgspec import Struct, to_builtins
import tractor
from .log import get_logger
from tractor.log import get_logger
_USE_POSIX = getattr(shm, '_USE_POSIX', False)
@ -837,8 +837,6 @@ def attach_shm_list(
if platform.system() == 'Linux':
import os
import errno
import string
import random
from contextlib import asynccontextmanager as acm
import cffi
@ -862,19 +860,21 @@ if platform.system() == 'Linux':
'''
)
# Open the default dynamic library (essentially 'libc' in most cases)
C = ffi.dlopen(None)
# Constants from <sys/eventfd.h>, if needed.
EFD_SEMAPHORE = 1 << 0 # 0x1
EFD_CLOEXEC = 1 << 1 # 0x2
EFD_NONBLOCK = 1 << 2 # 0x4
EFD_SEMAPHORE = 1
EFD_CLOEXEC = 0o2000000
EFD_NONBLOCK = 0o4000
def open_eventfd(initval: int = 0, flags: int = 0) -> int:
'''
Open an eventfd with the given initial value and flags.
Returns the file descriptor on success, otherwise raises OSError.
'''
fd = C.eventfd(initval, flags)
if fd < 0:
@ -884,6 +884,7 @@ if platform.system() == 'Linux':
def write_eventfd(fd: int, value: int) -> int:
'''
Write a 64-bit integer (uint64_t) to the eventfd's counter.
'''
# Create a uint64_t* in C, store `value`
data_ptr = ffi.new('uint64_t *', value)
@ -899,6 +900,7 @@ if platform.system() == 'Linux':
'''
Read a 64-bit integer (uint64_t) from the eventfd, returning the value.
Reading resets the counter to 0 (unless using EFD_SEMAPHORE).
'''
# Allocate an 8-byte buffer in C for reading
buf = ffi.new('char[]', 8)
@ -914,6 +916,7 @@ if platform.system() == 'Linux':
def close_eventfd(fd: int) -> int:
'''
Close the eventfd.
'''
ret = C.close(fd)
if ret < 0:
@ -921,17 +924,19 @@ if platform.system() == 'Linux':
class EventFD:
'''
Use a previously opened eventfd(2), meant to be used in
sub-actors after root actor opens the eventfds then passes
them through pass_fds
'''
def __init__(
self,
initval: int = 0,
flags: int = 0,
fd: int | None = None,
omode: str = 'r'
fd: int,
omode: str
):
self._initval: int = initval
self._flags: int = flags
self._fd: int | None = fd
self._fd: int = fd
self._omode: str = omode
self._fobj = None
@ -943,23 +948,15 @@ if platform.system() == 'Linux':
return write_eventfd(self._fd, value)
async def read(self) -> int:
#TODO: how to handle signals?
return await trio.to_thread.run_sync(read_eventfd, self._fd)
def open(self):
if not self._fd:
self._fd = open_eventfd(
initval=self._initval, flags=self._flags)
else:
self._fobj = os.fdopen(self._fd, self._omode)
self._fobj = os.fdopen(self._fd, self._omode)
def close(self):
if self._fobj:
self._fobj.close()
return
if self._fd:
close_eventfd(self._fd)
def __enter__(self):
self.open()
@ -970,18 +967,34 @@ if platform.system() == 'Linux':
class RingBuffSender(trio.abc.SendStream):
'''
IPC Reliable Ring Buffer sender side implementation
`eventfd(2)` is used for wrap around sync, and also to signal
writes to the reader.
TODO: if blocked on wrap around event wait it will not respond
to signals, fix soon TM
'''
def __init__(
self,
shm: SharedMemory,
write_event: EventFD,
wrap_event: EventFD,
start_ptr: int = 0
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
start_ptr: int = 0,
buf_size: int = 10 * 1024,
clean_shm_on_exit: bool = True
):
self._shm: SharedMemory = shm
self._write_event = write_event
self._wrap_event = wrap_event
self._shm = SharedMemory(
name=shm_key,
size=buf_size,
create=True
)
self._write_event = EventFD(write_eventfd, 'w')
self._wrap_event = EventFD(wrap_eventfd, 'r')
self._ptr = start_ptr
self.clean_shm_on_exit = clean_shm_on_exit
@property
def key(self) -> str:
@ -1004,25 +1017,37 @@ if platform.system() == 'Linux':
return self._wrap_event.fd
async def send_all(self, data: bytes | bytearray | memoryview):
# while data is larger than the remaining buf
target_ptr = self.ptr + len(data)
if target_ptr > self.size:
while target_ptr > self.size:
# write all bytes that fit
remaining = self.size - self.ptr
self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around
self._write_event.write(remaining)
await self._wrap_event.read()
# wrap around and trim already written bytes
self._ptr = 0
data = data[remaining:]
target_ptr = self._ptr + len(data)
# remaining data fits on buffer
self._shm.buf[self.ptr:target_ptr] = data
self._write_event.write(len(data))
self._ptr = target_ptr
async def wait_send_all_might_not_block(self):
...
raise NotImplementedError
async def aclose(self):
...
self._write_event.close()
self._wrap_event.close()
if self.clean_shm_on_exit:
self._shm.unlink()
else:
self._shm.close()
async def __aenter__(self):
self._write_event.open()
@ -1034,18 +1059,37 @@ if platform.system() == 'Linux':
class RingBuffReceiver(trio.abc.ReceiveStream):
'''
IPC Reliable Ring Buffer receiver side implementation
`eventfd(2)` is used for wrap around sync, and also to signal
writes to the reader.
Unless eventfd(2) object is opened with EFD_NONBLOCK flag,
calls to `receive_some` will block the signal handling,
on the main thread, for now solution is using polling,
working on a way to unblock GIL during read(2) to allow
signal processing on the main thread.
'''
def __init__(
self,
shm: SharedMemory,
write_event: EventFD,
wrap_event: EventFD,
start_ptr: int = 0
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
start_ptr: int = 0,
buf_size: int = 10 * 1024,
flags: int = 0
):
self._shm: SharedMemory = shm
self._write_event = write_event
self._wrap_event = wrap_event
self._shm = SharedMemory(
name=shm_key,
size=buf_size,
create=False
)
self._write_event = EventFD(write_eventfd, 'w')
self._wrap_event = EventFD(wrap_eventfd, 'r')
self._ptr = start_ptr
self._flags = flags
@property
def key(self) -> str:
@ -1067,18 +1111,44 @@ if platform.system() == 'Linux':
def wrap_fd(self) -> int:
return self._wrap_event.fd
async def receive_some(self, max_bytes: int | None = None) -> bytes:
delta = await self._write_event.read()
async def receive_some(
self,
max_bytes: int | None = None,
nb_timeout: float = 0.1
) -> memoryview:
# if non blocking eventfd enabled, do polling
# until next write, this allows signal handling
if self._flags | EFD_NONBLOCK:
delta = None
while delta is None:
try:
delta = await self._write_event.read()
except OSError as e:
if e.errno == 'EAGAIN':
continue
raise e
else:
delta = await self._write_event.read()
# fetch next segment and advance ptr
next_ptr = self._ptr + delta
segment = bytes(self._shm.buf[self._ptr:next_ptr])
segment = self._shm.buf[self._ptr:next_ptr]
self._ptr = next_ptr
if self.ptr == self.size:
# reached the end, signal wrap around
self._ptr = 0
self._wrap_event.write(1)
return segment
async def aclose(self):
...
self._write_event.close()
self._wrap_event.close()
self._shm.close()
async def __aenter__(self):
self._write_event.open()
@ -1087,42 +1157,3 @@ if platform.system() == 'Linux':
async def __aexit__(self, exc_type, exc_value, traceback):
await self.aclose()
@acm
async def open_ringbuffer_sender(
write_event_fd: int,
wrap_event_fd: int,
key: str | None = None,
max_bytes: int = 10 * 1024,
start_ptr: int = 0,
) -> RingBuffSender:
if not key:
key: str = ''.join(random.choice(string.ascii_lowercase) for i in range(32))
shm = SharedMemory(
name=key,
size=max_bytes,
create=True
)
async with RingBuffSender(
shm, EventFD(fd=write_event_fd, omode='w'), EventFD(fd=wrap_event_fd), start_ptr=start_ptr
) as s:
yield s
@acm
async def open_ringbuffer_receiver(
write_event_fd: int,
wrap_event_fd: int,
key: str,
max_bytes: int = 10 * 1024,
start_ptr: int = 0,
) -> RingBuffSender:
shm = SharedMemory(
name=key,
size=max_bytes,
create=False
)
async with RingBuffReceiver(
shm, EventFD(fd=write_event_fd), EventFD(fd=wrap_event_fd, omode='w'), start_ptr=start_ptr
) as r:
yield r

View File

@ -92,7 +92,7 @@ class StackLevelAdapter(LoggerAdapter):
) -> None:
'''
IPC transport level msg IO; generally anything below
`._ipc.Channel` and friends.
`.ipc.Channel` and friends.
'''
return self.log(5, msg)
@ -285,7 +285,7 @@ def get_logger(
# NOTE: for handling for modules that use ``get_logger(__name__)``
# we make the following stylistic choice:
# - always avoid duplicate project-package token
# in msg output: i.e. tractor.tractor _ipc.py in header
# in msg output: i.e. tractor.tractor.ipc._chan.py in header
# looks ridiculous XD
# - never show the leaf module name in the {name} part
# since in python the {filename} is always this same