diff --git a/tests/test_eventfd.py b/tests/test_eventfd.py new file mode 100644 index 00000000..3432048b --- /dev/null +++ b/tests/test_eventfd.py @@ -0,0 +1,66 @@ +import trio +import pytest +from tractor.linux.eventfd import ( + open_eventfd, + EFDReadCancelled, + EventFD +) + + +def test_read_cancellation(): + ''' + Ensure EventFD.read raises EFDReadCancelled if EventFD.close() + is called. + + ''' + fd = open_eventfd() + + async def bg_read(event: EventFD): + with pytest.raises(EFDReadCancelled): + await event.read() + + async def main(): + async with trio.open_nursery() as n: + with ( + EventFD(fd, 'w') as event, + trio.fail_after(3) + ): + n.start_soon(bg_read, event) + await trio.sleep(0.2) + event.close() + + trio.run(main) + + +def test_read_trio_semantics(): + ''' + Ensure EventFD.read raises trio.ClosedResourceError and + trio.BusyResourceError. + + ''' + + fd = open_eventfd() + + async def bg_read(event: EventFD): + try: + await event.read() + + except EFDReadCancelled: + ... + + async def main(): + async with trio.open_nursery() as n: + + # start background read and attempt + # foreground read, should be busy + with EventFD(fd, 'w') as event: + n.start_soon(bg_read, event) + await trio.sleep(0.2) + with pytest.raises(trio.BusyResourceError): + await event.read() + + # attempt read after close + with pytest.raises(trio.ClosedResourceError): + await event.read() + + trio.run(main) diff --git a/tests/test_ring_pubsub.py b/tests/test_ring_pubsub.py new file mode 100644 index 00000000..b3b0dade --- /dev/null +++ b/tests/test_ring_pubsub.py @@ -0,0 +1,185 @@ +from typing import AsyncContextManager +from contextlib import asynccontextmanager as acm + +import trio +import pytest +import tractor + +from tractor.trionics import gather_contexts + +from tractor.ipc._ringbuf import open_ringbufs +from tractor.ipc._ringbuf._pubsub import ( + open_ringbuf_publisher, + open_ringbuf_subscriber, + get_publisher, + get_subscriber, + open_pub_channel_at, + open_sub_channel_at +) + + +log = tractor.log.get_console_log(level='info') + + +@tractor.context +async def publish_range( + ctx: tractor.Context, + size: int +): + pub = get_publisher() + await ctx.started() + for i in range(size): + await pub.send(i.to_bytes(4)) + log.info(f'sent {i}') + + await pub.flush() + + log.info('range done') + + +@tractor.context +async def subscribe_range( + ctx: tractor.Context, + size: int +): + sub = get_subscriber() + await ctx.started() + + for i in range(size): + recv = int.from_bytes(await sub.receive()) + if recv != i: + raise AssertionError( + f'received: {recv} expected: {i}' + ) + + log.info(f'received: {recv}') + + log.info('range done') + + +@tractor.context +async def subscriber_child(ctx: tractor.Context): + try: + async with open_ringbuf_subscriber(guarantee_order=True): + await ctx.started() + await trio.sleep_forever() + + finally: + log.info('subscriber exit') + + +@tractor.context +async def publisher_child( + ctx: tractor.Context, + batch_size: int +): + try: + async with open_ringbuf_publisher( + guarantee_order=True, + batch_size=batch_size + ): + await ctx.started() + await trio.sleep_forever() + + finally: + log.info('publisher exit') + + +@acm +async def open_pubsub_test_actors( + + ring_names: list[str], + size: int, + batch_size: int + +) -> AsyncContextManager[tuple[tractor.Portal, tractor.Portal]]: + + with trio.fail_after(5): + async with tractor.open_nursery( + enable_modules=[ + 'tractor.linux._fdshare' + ] + ) as an: + modules = [ + __name__, + 'tractor.linux._fdshare', + 
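+                # NOTE: sub-actors spawned below must be able to import
+                # every module that defines the `@tractor.context`
+                # endpoints used in this test, so list them explicitly: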
'tractor.ipc._ringbuf._pubsub' + ] + sub_portal = await an.start_actor( + 'sub', + enable_modules=modules + ) + pub_portal = await an.start_actor( + 'pub', + enable_modules=modules + ) + + async with ( + sub_portal.open_context(subscriber_child) as (long_rctx, _), + pub_portal.open_context( + publisher_child, + batch_size=batch_size + ) as (long_sctx, _), + + open_ringbufs(ring_names) as tokens, + + gather_contexts([ + open_sub_channel_at('sub', ring) + for ring in tokens + ]), + gather_contexts([ + open_pub_channel_at('pub', ring) + for ring in tokens + ]), + sub_portal.open_context(subscribe_range, size=size) as (rctx, _), + pub_portal.open_context(publish_range, size=size) as (sctx, _) + ): + yield + + await rctx.wait_for_result() + await sctx.wait_for_result() + + await long_sctx.cancel() + await long_rctx.cancel() + + await an.cancel() + + +@pytest.mark.parametrize( + ('ring_names', 'size', 'batch_size'), + [ + ( + ['ring-first'], + 100, + 1 + ), + ( + ['ring-first'], + 69, + 1 + ), + ( + [f'multi-ring-{i}' for i in range(3)], + 1000, + 100 + ), + ], + ids=[ + 'simple', + 'redo-simple', + 'multi-ring', + ] +) +def test_pubsub( + request, + ring_names: list[str], + size: int, + batch_size: int +): + async def main(): + async with open_pubsub_test_actors( + ring_names, size, batch_size + ): + ... + + trio.run(main) diff --git a/tests/test_ringbuf.py b/tests/test_ringbuf.py index 0d3b420b..1f6a1927 100644 --- a/tests/test_ringbuf.py +++ b/tests/test_ringbuf.py @@ -1,4 +1,5 @@ import time +import hashlib import trio import pytest @@ -6,36 +7,45 @@ import pytest import tractor from tractor.ipc._ringbuf import ( open_ringbuf, + open_ringbuf_pair, + attach_to_ringbuf_receiver, + attach_to_ringbuf_sender, + attach_to_ringbuf_channel, RBToken, - RingBuffSender, - RingBuffReceiver ) from tractor._testing.samples import ( - generate_sample_messages, + generate_single_byte_msgs, + RandomBytesGenerator ) -# in case you don't want to melt your cores, uncomment dis! -pytestmark = pytest.mark.skip - @tractor.context async def child_read_shm( ctx: tractor.Context, - msg_amount: int, token: RBToken, - total_bytes: int, -) -> None: - recvd_bytes = 0 - await ctx.started() - start_ts = time.time() - async with RingBuffReceiver(token) as receiver: - while recvd_bytes < total_bytes: - msg = await receiver.receive_some() - recvd_bytes += len(msg) +) -> str: + ''' + Sub-actor used in `test_ringbuf`. - # make sure we dont hold any memoryviews - # before the ctx manager aclose() - msg = None + Attach to a ringbuf and receive all messages until end of stream. + Keep track of how many bytes received and also calculate + sha256 of the whole byte stream. + + Calculate and print performance stats, finally return calculated + hash. 
+ + ''' + await ctx.started() + print('reader started') + msg_amount = 0 + recvd_bytes = 0 + recvd_hash = hashlib.sha256() + start_ts = time.time() + async with attach_to_ringbuf_receiver(token) as receiver: + async for msg in receiver: + msg_amount += 1 + recvd_hash.update(msg) + recvd_bytes += len(msg) end_ts = time.time() elapsed = end_ts - start_ts @@ -44,6 +54,10 @@ async def child_read_shm( print(f'\n\telapsed ms: {elapsed_ms}') print(f'\tmsg/sec: {int(msg_amount / elapsed):,}') print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}') + print(f'\treceived msgs: {msg_amount:,}') + print(f'\treceived bytes: {recvd_bytes:,}') + + return recvd_hash.hexdigest() @tractor.context @@ -52,17 +66,37 @@ async def child_write_shm( msg_amount: int, rand_min: int, rand_max: int, - token: RBToken, + buf_size: int ) -> None: - msgs, total_bytes = generate_sample_messages( + ''' + Sub-actor used in `test_ringbuf` + + Generate `msg_amount` payloads with + `random.randint(rand_min, rand_max)` random bytes at the end, + Calculate sha256 hash and send it to parent on `ctx.started`. + + Attach to ringbuf and send all generated messages. + + ''' + rng = RandomBytesGenerator( msg_amount, rand_min=rand_min, rand_max=rand_max, ) - await ctx.started(total_bytes) - async with RingBuffSender(token) as sender: - for msg in msgs: - await sender.send_all(msg) + async with ( + open_ringbuf('test_ringbuf', buf_size=buf_size) as token, + attach_to_ringbuf_sender(token) as sender + ): + await ctx.started(token) + print('writer started') + for msg in rng: + await sender.send(msg) + + if rng.msgs_generated % rng.recommended_log_interval == 0: + print(f'wrote {rng.msgs_generated} msgs') + + print('writer exit') + return rng.hexdigest @pytest.mark.parametrize( @@ -89,84 +123,91 @@ def test_ringbuf( rand_max: int, buf_size: int ): + ''' + - Open a new ring buf on root actor + - Open `child_write_shm` ctx in sub-actor which will generate a + random payload and send its hash on `ctx.started`, finally sending + the payload through the stream. + - Open `child_read_shm` ctx in sub-actor which will receive the + payload, calculate perf stats and return the hash. 
+ - Compare both hashes + + ''' async def main(): - with open_ringbuf( - 'test_ringbuf', - buf_size=buf_size - ) as token: - proc_kwargs = { - 'pass_fds': (token.write_eventfd, token.wrap_eventfd) - } + async with tractor.open_nursery() as an: + send_p = await an.start_actor( + 'ring_sender', + enable_modules=[ + __name__, + 'tractor.linux._fdshare' + ], + ) + recv_p = await an.start_actor( + 'ring_receiver', + enable_modules=[ + __name__, + 'tractor.linux._fdshare' + ], + ) + async with ( + send_p.open_context( + child_write_shm, + msg_amount=msg_amount, + rand_min=rand_min, + rand_max=rand_max, + buf_size=buf_size + ) as (sctx, token), - common_kwargs = { - 'msg_amount': msg_amount, - 'token': token, - } - async with tractor.open_nursery() as an: - send_p = await an.start_actor( - 'ring_sender', - enable_modules=[__name__], - proc_kwargs=proc_kwargs - ) - recv_p = await an.start_actor( - 'ring_receiver', - enable_modules=[__name__], - proc_kwargs=proc_kwargs - ) - async with ( - send_p.open_context( - child_write_shm, - rand_min=rand_min, - rand_max=rand_max, - **common_kwargs - ) as (sctx, total_bytes), - recv_p.open_context( - child_read_shm, - **common_kwargs, - total_bytes=total_bytes, - ) as (sctx, _sent), - ): - await recv_p.result() + recv_p.open_context( + child_read_shm, + token=token, + ) as (rctx, _), + ): + sent_hash = await sctx.result() + recvd_hash = await rctx.result() - await send_p.cancel_actor() - await recv_p.cancel_actor() + assert sent_hash == recvd_hash + await an.cancel() trio.run(main) @tractor.context -async def child_blocked_receiver( - ctx: tractor.Context, - token: RBToken -): - async with RingBuffReceiver(token) as receiver: - await ctx.started() +async def child_blocked_receiver(ctx: tractor.Context): + async with ( + open_ringbuf('test_ring_cancel_reader') as token, + + attach_to_ringbuf_receiver(token) as receiver + ): + await ctx.started(token) await receiver.receive_some() -def test_ring_reader_cancel(): +def test_reader_cancel(): + ''' + Test that a receiver blocked on eventfd(2) read responds to + cancellation. 
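+    The child's blocked `await receiver.receive_some()` must abort
+    promptly once the actor nursery is cancelled.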
+ + ''' async def main(): - with open_ringbuf('test_ring_cancel_reader') as token: + async with tractor.open_nursery() as an: + recv_p = await an.start_actor( + 'ring_blocked_receiver', + enable_modules=[ + __name__, + 'tractor.linux._fdshare' + ], + ) async with ( - tractor.open_nursery() as an, - RingBuffSender(token) as _sender, + recv_p.open_context( + child_blocked_receiver, + ) as (sctx, token), + + attach_to_ringbuf_sender(token), ): - recv_p = await an.start_actor( - 'ring_blocked_receiver', - enable_modules=[__name__], - proc_kwargs={ - 'pass_fds': (token.write_eventfd, token.wrap_eventfd) - } - ) - async with ( - recv_p.open_context( - child_blocked_receiver, - token=token - ) as (sctx, _sent), - ): - await trio.sleep(1) - await an.cancel() + await trio.sleep(.1) + await an.cancel() with pytest.raises(tractor._exceptions.ContextCancelled): @@ -174,38 +215,166 @@ def test_ring_reader_cancel(): @tractor.context -async def child_blocked_sender( - ctx: tractor.Context, - token: RBToken -): - async with RingBuffSender(token) as sender: - await ctx.started() +async def child_blocked_sender(ctx: tractor.Context): + async with ( + open_ringbuf( + 'test_ring_cancel_sender', + buf_size=1 + ) as token, + + attach_to_ringbuf_sender(token) as sender + ): + await ctx.started(token) await sender.send_all(b'this will wrap') -def test_ring_sender_cancel(): +def test_sender_cancel(): + ''' + Test that a sender blocked on eventfd(2) read responds to + cancellation. + + ''' async def main(): - with open_ringbuf( - 'test_ring_cancel_sender', - buf_size=1 - ) as token: - async with tractor.open_nursery() as an: - recv_p = await an.start_actor( - 'ring_blocked_sender', - enable_modules=[__name__], - proc_kwargs={ - 'pass_fds': (token.write_eventfd, token.wrap_eventfd) - } - ) - async with ( - recv_p.open_context( - child_blocked_sender, - token=token - ) as (sctx, _sent), - ): - await trio.sleep(1) - await an.cancel() + async with tractor.open_nursery() as an: + recv_p = await an.start_actor( + 'ring_blocked_sender', + enable_modules=[ + __name__, + 'tractor.linux._fdshare' + ], + ) + async with ( + recv_p.open_context( + child_blocked_sender, + ) as (sctx, token), + + attach_to_ringbuf_receiver(token) + ): + await trio.sleep(.1) + await an.cancel() with pytest.raises(tractor._exceptions.ContextCancelled): trio.run(main) + + +def test_receiver_max_bytes(): + ''' + Test that RingBuffReceiver.receive_some's max_bytes optional + argument works correctly, send a msg of size 100, then + force receive of messages with max_bytes == 1, wait until + 100 of these messages are received, then compare join of + msgs with original message + + ''' + msg = generate_single_byte_msgs(100) + msgs = [] + + rb_common = { + 'cleanup': False, + 'is_ipc': False + } + + async def main(): + async with ( + open_ringbuf( + 'test_ringbuf_max_bytes', + buf_size=10, + is_ipc=False + ) as token, + + trio.open_nursery() as n, + + attach_to_ringbuf_sender(token, **rb_common) as sender, + + attach_to_ringbuf_receiver(token, **rb_common) as receiver + ): + async def _send_and_close(): + await sender.send_all(msg) + await sender.aclose() + + n.start_soon(_send_and_close) + while len(msgs) < len(msg): + msg_part = await receiver.receive_some(max_bytes=1) + assert len(msg_part) == 1 + msgs.append(msg_part) + + trio.run(main) + assert msg == b''.join(msgs) + + +@tractor.context +async def child_channel_sender( + ctx: tractor.Context, + msg_amount_min: int, + msg_amount_max: int, + token_in: RBToken, + token_out: RBToken +): + import random + 
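+    # randomize the total message count so repeated runs exercise
+    # different stream lengths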
rng = RandomBytesGenerator( + random.randint(msg_amount_min, msg_amount_max), + rand_min=256, + rand_max=1024, + ) + async with attach_to_ringbuf_channel( + token_in, + token_out + ) as chan: + await ctx.started() + for msg in rng: + await chan.send(msg) + + await chan.send(b'bye') + await chan.receive() + return rng.hexdigest + + +def test_channel(): + + msg_amount_min = 100 + msg_amount_max = 1000 + + mods = [ + __name__, + 'tractor.linux._fdshare' + ] + + async def main(): + async with ( + tractor.open_nursery(enable_modules=mods) as an, + + open_ringbuf_pair( + 'test_ringbuf_transport' + ) as (send_token, recv_token), + + attach_to_ringbuf_channel(send_token, recv_token) as chan, + ): + sender = await an.start_actor( + 'test_ringbuf_transport_sender', + enable_modules=mods, + ) + async with ( + sender.open_context( + child_channel_sender, + msg_amount_min=msg_amount_min, + msg_amount_max=msg_amount_max, + token_in=recv_token, + token_out=send_token + ) as (ctx, _), + ): + recvd_hash = hashlib.sha256() + async for msg in chan: + if msg == b'bye': + await chan.send(b'bye') + break + + recvd_hash.update(msg) + + sent_hash = await ctx.result() + + assert recvd_hash.hexdigest() == sent_hash + + await an.cancel() + + trio.run(main) diff --git a/tractor/_discovery.py b/tractor/_discovery.py index fd3e4b1c..478538c6 100644 --- a/tractor/_discovery.py +++ b/tractor/_discovery.py @@ -121,9 +121,14 @@ def get_peer_by_name( actor: Actor = current_actor() server: IPCServer = actor.ipc_server to_scan: dict[tuple, list[Channel]] = server._peers.copy() - pchan: Channel|None = actor._parent_chan - if pchan: - to_scan[pchan.uid].append(pchan) + + # TODO: is this ever needed? creates a duplicate channel on actor._peers + # when multiple find_actor calls are made to same actor from a single ctx + # which causes actor exit to hang waiting forever on + # `actor._no_more_peers.wait()` in `_runtime.async_main` + # pchan: Channel|None = actor._parent_chan + # if pchan: + # to_scan[pchan.uid].append(pchan) for aid, chans in to_scan.items(): _, peer_name = aid diff --git a/tractor/_testing/samples.py b/tractor/_testing/samples.py index a87a22c4..fcf41dfa 100644 --- a/tractor/_testing/samples.py +++ b/tractor/_testing/samples.py @@ -1,35 +1,99 @@ -import os -import random +import hashlib +import numpy as np -def generate_sample_messages( - amount: int, - rand_min: int = 0, - rand_max: int = 0, - silent: bool = False -) -> tuple[list[bytes], int]: +def generate_single_byte_msgs(amount: int) -> bytes: + ''' + Generate a byte instance of length `amount` with repeating ASCII digits 0..9. - msgs = [] - size = 0 + ''' + # array [0, 1, 2, ..., amount-1], take mod 10 => [0..9], and map 0->'0'(48) + # up to 9->'9'(57). + arr = np.arange(amount, dtype=np.uint8) % 10 + # move into ascii space + arr += 48 + return arr.tobytes() - if not silent: - print(f'\ngenerating {amount} messages...') - for i in range(amount): - msg = f'[{i:08}]'.encode('utf-8') +class RandomBytesGenerator: + ''' + Generate bytes msgs for tests. - if rand_max > 0: - msg += os.urandom( - random.randint(rand_min, rand_max)) + messages will have the following format: - size += len(msg) + b'[{i:08}]' + random_bytes - msgs.append(msg) + so for message index 25: - if not silent and i and i % 10_000 == 0: - print(f'{i} generated') + b'[00000025]' + random_bytes - if not silent: - print(f'done, {size:,} bytes in total') + also generates sha256 hash of msgs. 
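+
+    A minimal usage sketch:
+
+        rng = RandomBytesGenerator(100, rand_min=10, rand_max=100)
+        for msg in rng:
+            ...
+        digest = rng.hexdigest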
- return msgs, size + ''' + + def __init__( + self, + amount: int, + rand_min: int = 0, + rand_max: int = 0 + ): + if rand_max < rand_min: + raise ValueError('rand_max must be >= rand_min') + + self._amount = amount + self._rand_min = rand_min + self._rand_max = rand_max + self._index = 0 + self._hasher = hashlib.sha256() + self._total_bytes = 0 + + self._lengths = np.random.randint( + rand_min, + rand_max + 1, + size=amount, + dtype=np.int32 + ) + + def __iter__(self): + return self + + def __next__(self) -> bytes: + if self._index == self._amount: + raise StopIteration + + header = f'[{self._index:08}]'.encode('utf-8') + + length = int(self._lengths[self._index]) + msg = header + np.random.bytes(length) + + self._hasher.update(msg) + self._total_bytes += length + self._index += 1 + + return msg + + @property + def hexdigest(self) -> str: + return self._hasher.hexdigest() + + @property + def total_bytes(self) -> int: + return self._total_bytes + + @property + def total_msgs(self) -> int: + return self._amount + + @property + def msgs_generated(self) -> int: + return self._index + + @property + def recommended_log_interval(self) -> int: + max_msg_size = 10 + self._rand_max + + if max_msg_size <= 32 * 1024: + return 10_000 + + else: + return 1000 diff --git a/tractor/ipc/__init__.py b/tractor/ipc/__init__.py index 2c6c3b5d..37c3c8ed 100644 --- a/tractor/ipc/__init__.py +++ b/tractor/ipc/__init__.py @@ -13,7 +13,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - ''' A modular IPC layer supporting the power of cross-process SC! diff --git a/tractor/ipc/_ringbuf.py b/tractor/ipc/_ringbuf.py deleted file mode 100644 index 6337eea1..00000000 --- a/tractor/ipc/_ringbuf.py +++ /dev/null @@ -1,253 +0,0 @@ -# tractor: structured concurrent "actors". -# Copyright 2018-eternity Tyler Goodlet. - -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. - -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. - -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . 
-''' -IPC Reliable RingBuffer implementation - -''' -from __future__ import annotations -from contextlib import contextmanager as cm -from multiprocessing.shared_memory import SharedMemory - -import trio -from msgspec import ( - Struct, - to_builtins -) - -from ._linux import ( - EFD_NONBLOCK, - open_eventfd, - EventFD -) -from ._mp_bs import disable_mantracker - - -disable_mantracker() - - -class RBToken(Struct, frozen=True): - ''' - RingBuffer token contains necesary info to open the two - eventfds and the shared memory - - ''' - shm_name: str - write_eventfd: int - wrap_eventfd: int - buf_size: int - - def as_msg(self): - return to_builtins(self) - - @classmethod - def from_msg(cls, msg: dict) -> RBToken: - if isinstance(msg, RBToken): - return msg - - return RBToken(**msg) - - -@cm -def open_ringbuf( - shm_name: str, - buf_size: int = 10 * 1024, - write_efd_flags: int = 0, - wrap_efd_flags: int = 0 -) -> RBToken: - shm = SharedMemory( - name=shm_name, - size=buf_size, - create=True - ) - try: - token = RBToken( - shm_name=shm_name, - write_eventfd=open_eventfd(flags=write_efd_flags), - wrap_eventfd=open_eventfd(flags=wrap_efd_flags), - buf_size=buf_size - ) - yield token - - finally: - shm.unlink() - - -class RingBuffSender(trio.abc.SendStream): - ''' - IPC Reliable Ring Buffer sender side implementation - - `eventfd(2)` is used for wrap around sync, and also to signal - writes to the reader. - - ''' - def __init__( - self, - token: RBToken, - start_ptr: int = 0, - ): - token = RBToken.from_msg(token) - self._shm = SharedMemory( - name=token.shm_name, - size=token.buf_size, - create=False - ) - self._write_event = EventFD(token.write_eventfd, 'w') - self._wrap_event = EventFD(token.wrap_eventfd, 'r') - self._ptr = start_ptr - - @property - def key(self) -> str: - return self._shm.name - - @property - def size(self) -> int: - return self._shm.size - - @property - def ptr(self) -> int: - return self._ptr - - @property - def write_fd(self) -> int: - return self._write_event.fd - - @property - def wrap_fd(self) -> int: - return self._wrap_event.fd - - async def send_all(self, data: bytes | bytearray | memoryview): - # while data is larger than the remaining buf - target_ptr = self.ptr + len(data) - while target_ptr > self.size: - # write all bytes that fit - remaining = self.size - self.ptr - self._shm.buf[self.ptr:] = data[:remaining] - # signal write and wait for reader wrap around - self._write_event.write(remaining) - await self._wrap_event.read() - - # wrap around and trim already written bytes - self._ptr = 0 - data = data[remaining:] - target_ptr = self._ptr + len(data) - - # remaining data fits on buffer - self._shm.buf[self.ptr:target_ptr] = data - self._write_event.write(len(data)) - self._ptr = target_ptr - - async def wait_send_all_might_not_block(self): - raise NotImplementedError - - async def aclose(self): - self._write_event.close() - self._wrap_event.close() - self._shm.close() - - async def __aenter__(self): - self._write_event.open() - self._wrap_event.open() - return self - - -class RingBuffReceiver(trio.abc.ReceiveStream): - ''' - IPC Reliable Ring Buffer receiver side implementation - - `eventfd(2)` is used for wrap around sync, and also to signal - writes to the reader. 
- - ''' - def __init__( - self, - token: RBToken, - start_ptr: int = 0, - flags: int = 0 - ): - token = RBToken.from_msg(token) - self._shm = SharedMemory( - name=token.shm_name, - size=token.buf_size, - create=False - ) - self._write_event = EventFD(token.write_eventfd, 'w') - self._wrap_event = EventFD(token.wrap_eventfd, 'r') - self._ptr = start_ptr - self._flags = flags - - @property - def key(self) -> str: - return self._shm.name - - @property - def size(self) -> int: - return self._shm.size - - @property - def ptr(self) -> int: - return self._ptr - - @property - def write_fd(self) -> int: - return self._write_event.fd - - @property - def wrap_fd(self) -> int: - return self._wrap_event.fd - - async def receive_some( - self, - max_bytes: int | None = None, - nb_timeout: float = 0.1 - ) -> memoryview: - # if non blocking eventfd enabled, do polling - # until next write, this allows signal handling - if self._flags | EFD_NONBLOCK: - delta = None - while delta is None: - try: - delta = await self._write_event.read() - - except OSError as e: - if e.errno == 'EAGAIN': - continue - - raise e - - else: - delta = await self._write_event.read() - - # fetch next segment and advance ptr - next_ptr = self._ptr + delta - segment = self._shm.buf[self._ptr:next_ptr] - self._ptr = next_ptr - - if self.ptr == self.size: - # reached the end, signal wrap around - self._ptr = 0 - self._wrap_event.write(1) - - return segment - - async def aclose(self): - self._write_event.close() - self._wrap_event.close() - self._shm.close() - - async def __aenter__(self): - self._write_event.open() - self._wrap_event.open() - return self diff --git a/tractor/ipc/_ringbuf/__init__.py b/tractor/ipc/_ringbuf/__init__.py new file mode 100644 index 00000000..12943707 --- /dev/null +++ b/tractor/ipc/_ringbuf/__init__.py @@ -0,0 +1,1004 @@ +# tractor: structured concurrent "actors". +# Copyright 2018-eternity Tyler Goodlet. + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
+''' +IPC Reliable RingBuffer implementation + +''' +from __future__ import annotations +import struct +from typing import ( + TypeVar, + ContextManager, + AsyncContextManager +) +from contextlib import ( + contextmanager as cm, + asynccontextmanager as acm +) +from multiprocessing.shared_memory import SharedMemory + +import trio +from msgspec import ( + Struct, + to_builtins +) +from msgspec.msgpack import ( + Encoder, + Decoder, +) + +from tractor.log import get_logger +from tractor._exceptions import ( + InternalError +) +from tractor.ipc._mp_bs import disable_mantracker +from tractor.linux._fdshare import ( + share_fds, + unshare_fds, + request_fds_from +) +from tractor.linux.eventfd import ( + open_eventfd, + EFDReadCancelled, + EventFD +) +from tractor._state import current_actor + + +log = get_logger(__name__) + + +disable_mantracker() + +_DEFAULT_RB_SIZE = 10 * 1024 + + +class RBToken(Struct, frozen=True): + ''' + RingBuffer token contains necesary info to open resources of a ringbuf, + even in the case that ringbuf was not allocated by current actor. + + ''' + owner: str | None # if owner != `current_actor().name` we must use fdshare + + shm_name: str + + write_eventfd: int # used to signal writer ptr advance + wrap_eventfd: int # used to signal reader ready after wrap around + eof_eventfd: int # used to signal writer closed + + buf_size: int # size in bytes of underlying shared memory buffer + + def as_msg(self): + return to_builtins(self) + + @classmethod + def from_msg(cls, msg: dict) -> RBToken: + if isinstance(msg, RBToken): + return msg + + return RBToken(**msg) + + @property + def fds(self) -> tuple[int, int, int]: + return ( + self.write_eventfd, + self.wrap_eventfd, + self.eof_eventfd + ) + + +def alloc_ringbuf( + shm_name: str, + buf_size: int = _DEFAULT_RB_SIZE, + is_ipc: bool = True +) -> tuple[SharedMemory, RBToken]: + ''' + Allocate OS resources for a ringbuf. + ''' + shm = SharedMemory( + name=shm_name, + size=buf_size, + create=True + ) + token = RBToken( + owner=current_actor().name if is_ipc else None, + shm_name=shm_name, + write_eventfd=open_eventfd(), + wrap_eventfd=open_eventfd(), + eof_eventfd=open_eventfd(), + buf_size=buf_size + ) + + if is_ipc: + # register fds for sharing + share_fds( + shm_name, + token.fds, + ) + + return shm, token + + +@cm +def open_ringbuf_sync( + shm_name: str, + buf_size: int = _DEFAULT_RB_SIZE, + is_ipc: bool = True +) -> ContextManager[RBToken]: + ''' + Handle resources for a ringbuf (shm, eventfd), yield `RBToken` to + be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver`, + post yield maybe unshare fds and unlink shared memory + + ''' + shm: SharedMemory | None = None + token: RBToken | None = None + try: + shm, token = alloc_ringbuf( + shm_name, + buf_size=buf_size, + is_ipc=is_ipc + ) + yield token + + finally: + if token and is_ipc: + unshare_fds(shm_name) + + if shm: + shm.unlink() + +@acm +async def open_ringbuf( + shm_name: str, + buf_size: int = _DEFAULT_RB_SIZE, + is_ipc: bool = True +) -> AsyncContextManager[RBToken]: + ''' + Helper to use `open_ringbuf_sync` inside an async with block. + + ''' + with open_ringbuf_sync( + shm_name, + buf_size=buf_size, + is_ipc=is_ipc + ) as token: + yield token + + +@cm +def open_ringbufs_sync( + shm_names: list[str], + buf_sizes: int | list[str] = _DEFAULT_RB_SIZE, + is_ipc: bool = True +) -> ContextManager[tuple[RBToken]]: + ''' + Handle resources for multiple ringbufs at once. 
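+
+    A minimal sketch (ring names are illustrative):
+
+        with open_ringbufs_sync(['ring-a', 'ring-b']) as (token_a, token_b):
+            ...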
+ + ''' + # maybe convert single int into list + if isinstance(buf_sizes, int): + buf_size = [buf_sizes] * len(shm_names) + + # ensure len(shm_names) == len(buf_sizes) + if ( + isinstance(buf_sizes, list) + and + len(buf_sizes) != len(shm_names) + ): + raise ValueError( + 'Expected buf_size list to be same length as shm_names' + ) + + # allocate resources + rings: list[tuple[SharedMemory, RBToken]] = [ + alloc_ringbuf( + shm_name, + buf_size=buf_size, + is_ipc=is_ipc + ) + for shm_name, buf_size in zip(shm_names, buf_size) + ] + + try: + yield tuple([token for _, token in rings]) + + finally: + # attempt fd unshare and shm unlink for each + for shm, token in rings: + if is_ipc: + try: + unshare_fds(token.shm_name) + + except RuntimeError: + log.exception(f'while unsharing fds of {token}') + + shm.unlink() + + +@acm +async def open_ringbufs( + shm_names: list[str], + buf_sizes: int | list[str] = _DEFAULT_RB_SIZE, + is_ipc: bool = True +) -> AsyncContextManager[tuple[RBToken]]: + ''' + Helper to use `open_ringbufs_sync` inside an async with block. + + ''' + with open_ringbufs_sync( + shm_names, + buf_sizes=buf_sizes, + is_ipc=is_ipc + ) as tokens: + yield tokens + + +@cm +def open_ringbuf_pair_sync( + shm_name: str, + buf_size: int = _DEFAULT_RB_SIZE, + is_ipc: bool = True +) -> ContextManager[tuple(RBToken, RBToken)]: + ''' + Handle resources for a ringbuf pair to be used for + bidirectional messaging. + + ''' + with open_ringbufs_sync( + [ + f'{shm_name}.send', + f'{shm_name}.recv' + ], + buf_sizes=buf_size, + is_ipc=is_ipc + ) as tokens: + yield tokens + + +@acm +async def open_ringbuf_pair( + shm_name: str, + buf_size: int = _DEFAULT_RB_SIZE, + is_ipc: bool = True +) -> AsyncContextManager[tuple[RBToken, RBToken]]: + ''' + Helper to use `open_ringbuf_pair_sync` inside an async with block. + + ''' + with open_ringbuf_pair_sync( + shm_name, + buf_size=buf_size, + is_ipc=is_ipc + ) as tokens: + yield tokens + + +Buffer = bytes | bytearray | memoryview + + +''' +IPC Reliable Ring Buffer + +`eventfd(2)` is used for wrap around sync, to signal writes to +the reader and end of stream. + +In order to guarantee full messages are received, all bytes +sent by `RingBufferSendChannel` are preceded with a 4 byte header +which decodes into a uint32 indicating the actual size of the +next full payload. + +''' + + +PayloadT = TypeVar('PayloadT') + + +class RingBufferSendChannel(trio.abc.SendChannel[PayloadT]): + ''' + Ring Buffer sender side implementation + + Do not use directly! manage with `attach_to_ringbuf_sender` + after having opened a ringbuf context with `open_ringbuf`. + + Optional batch mode: + + If `batch_size` > 1 messages wont get sent immediately but will be + stored until `batch_size` messages are pending, then it will send + them all at once. + + `batch_size` can be changed dynamically but always call, `flush()` + right before. 
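+
+    A minimal usage sketch (assuming `token` comes from `open_ringbuf`):
+
+        async with attach_to_ringbuf_sender(token, batch_size=4) as sender:
+            await sender.send(b'msg')
+            await sender.flush()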
+ + ''' + def __init__( + self, + token: RBToken, + batch_size: int = 1, + cleanup: bool = False, + encoder: Encoder | None = None + ): + self._token = RBToken.from_msg(token) + self.batch_size = batch_size + + # ringbuf os resources + self._shm: SharedMemory | None = None + self._write_event = EventFD(self._token.write_eventfd, 'w') + self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') + self._eof_event = EventFD(self._token.eof_eventfd, 'w') + + # current write pointer + self._ptr: int = 0 + + # when `batch_size` > 1 store messages on `self._batch` and write them + # all, once `len(self._batch) == `batch_size` + self._batch: list[bytes] = [] + + # close shm & fds on exit? + self._cleanup: bool = cleanup + + self._enc: Encoder | None = encoder + + # have we closed this ringbuf? + # set to `False` on `.open()` + self._is_closed: bool = True + + # ensure no concurrent `.send_all()` calls + self._send_all_lock = trio.StrictFIFOLock() + + # ensure no concurrent `.send()` calls + self._send_lock = trio.StrictFIFOLock() + + # ensure no concurrent `.flush()` calls + self._flush_lock = trio.StrictFIFOLock() + + @property + def closed(self) -> bool: + return self._is_closed + + @property + def name(self) -> str: + if not self._shm: + raise ValueError('shared memory not initialized yet!') + return self._shm.name + + @property + def size(self) -> int: + return self._token.buf_size + + @property + def ptr(self) -> int: + return self._ptr + + @property + def write_fd(self) -> int: + return self._write_event.fd + + @property + def wrap_fd(self) -> int: + return self._wrap_event.fd + + @property + def pending_msgs(self) -> int: + return len(self._batch) + + @property + def must_flush(self) -> bool: + return self.pending_msgs >= self.batch_size + + async def _wait_wrap(self): + await self._wrap_event.read() + + async def send_all(self, data: Buffer): + if self.closed: + raise trio.ClosedResourceError + + if self._send_all_lock.locked(): + raise trio.BusyResourceError + + async with self._send_all_lock: + # while data is larger than the remaining buf + target_ptr = self.ptr + len(data) + while target_ptr > self.size: + # write all bytes that fit + remaining = self.size - self.ptr + self._shm.buf[self.ptr:] = data[:remaining] + # signal write and wait for reader wrap around + self._write_event.write(remaining) + await self._wait_wrap() + + # wrap around and trim already written bytes + self._ptr = 0 + data = data[remaining:] + target_ptr = self._ptr + len(data) + + # remaining data fits on buffer + self._shm.buf[self.ptr:target_ptr] = data + self._write_event.write(len(data)) + self._ptr = target_ptr + + async def wait_send_all_might_not_block(self): + return + + async def flush( + self, + new_batch_size: int | None = None + ) -> None: + if self.closed: + raise trio.ClosedResourceError + + async with self._flush_lock: + for msg in self._batch: + await self.send_all(msg) + + self._batch = [] + if new_batch_size: + self.batch_size = new_batch_size + + async def send(self, value: PayloadT) -> None: + if self.closed: + raise trio.ClosedResourceError + + if self._send_lock.locked(): + raise trio.BusyResourceError + + raw_value: bytes = ( + value + if isinstance(value, bytes) + else + self._enc.encode(value) + ) + + async with self._send_lock: + msg: bytes = struct.pack(" 0: + await self.flush() + + await self.send_all(msg) + return + + self._batch.append(msg) + if self.must_flush: + await self.flush() + + def open(self): + try: + self._shm = SharedMemory( + name=self._token.shm_name, + 
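+                # attach to the segment allocated by `open_ringbuf`;
+                # the channel side must never create it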
size=self._token.buf_size, + create=False + ) + self._write_event.open() + self._wrap_event.open() + self._eof_event.open() + self._is_closed = False + + except Exception as e: + e.add_note(f'while opening sender for {self._token.as_msg()}') + raise e + + def _close(self): + self._eof_event.write( + self._ptr if self._ptr > 0 else self.size + ) + + if self._cleanup: + self._write_event.close() + self._wrap_event.close() + self._eof_event.close() + self._shm.close() + + self._is_closed = True + + async def aclose(self): + if self.closed: + return + + self._close() + + async def __aenter__(self): + self.open() + return self + + +class RingBufferReceiveChannel(trio.abc.ReceiveChannel[PayloadT]): + ''' + Ring Buffer receiver side implementation + + Do not use directly! manage with `attach_to_ringbuf_receiver` + after having opened a ringbuf context with `open_ringbuf`. + + ''' + def __init__( + self, + token: RBToken, + cleanup: bool = True, + decoder: Decoder | None = None + ): + self._token = RBToken.from_msg(token) + + # ringbuf os resources + self._shm: SharedMemory | None = None + self._write_event = EventFD(self._token.write_eventfd, 'w') + self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') + self._eof_event = EventFD(self._token.eof_eventfd, 'r') + + # current read ptr + self._ptr: int = 0 + + # current write_ptr (max bytes we can read from buf) + self._write_ptr: int = 0 + + # end ptr is used when EOF is signaled, it will contain maximun + # readable position on buf + self._end_ptr: int = -1 + + # close shm & fds on exit? + self._cleanup: bool = cleanup + + # have we closed this ringbuf? + # set to `False` on `.open()` + self._is_closed: bool = True + + self._dec: Decoder | None = decoder + + # ensure no concurrent `.receive_some()` calls + self._receive_some_lock = trio.StrictFIFOLock() + + # ensure no concurrent `.receive_exactly()` calls + self._receive_exactly_lock = trio.StrictFIFOLock() + + # ensure no concurrent `.receive()` calls + self._receive_lock = trio.StrictFIFOLock() + + @property + def closed(self) -> bool: + return self._is_closed + + @property + def name(self) -> str: + if not self._shm: + raise ValueError('shared memory not initialized yet!') + return self._shm.name + + @property + def size(self) -> int: + return self._token.buf_size + + @property + def ptr(self) -> int: + return self._ptr + + @property + def write_fd(self) -> int: + return self._write_event.fd + + @property + def wrap_fd(self) -> int: + return self._wrap_event.fd + + @property + def eof_was_signaled(self) -> bool: + return self._end_ptr != -1 + + async def _eof_monitor_task(self): + ''' + Long running EOF event monitor, automatically run in bg by + `attach_to_ringbuf_receiver` context manager, if EOF event + is set its value will be the end pointer (highest valid + index to be read from buf, after setting the `self._end_ptr` + we close the write event which should cancel any blocked + `self._write_event.read()`s on it. + + ''' + try: + self._end_ptr = await self._eof_event.read() + + except EFDReadCancelled: + ... + + except trio.Cancelled: + ... + + finally: + # closing write_event should trigger `EFDReadCancelled` + # on any pending read + self._write_event.close() + + def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: + ''' + Try to receive any bytes we can without blocking or raise + `trio.WouldBlock`. + + Returns b'' when no more bytes can be read (EOF signaled & read all). 
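+
+        Raises `trio.WouldBlock` if the buffer is currently empty but
+        the stream is still open.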
+ + ''' + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") + + # in case `end_ptr` is set that means eof was signaled. + # it will be >= `write_ptr`, use it for delta calc + highest_ptr = max(self._write_ptr, self._end_ptr) + + delta = highest_ptr - self._ptr + + # no more bytes to read + if delta == 0: + # if `end_ptr` is set that means we read all bytes before EOF + if self.eof_was_signaled: + return b'' + + # signal the need to wait on `write_event` + raise trio.WouldBlock + + # dont overflow caller + delta = min(delta, max_bytes) + + target_ptr = self._ptr + delta + + # fetch next segment and advance ptr + segment = bytes(self._shm.buf[self._ptr:target_ptr]) + self._ptr = target_ptr + + if self._ptr == self.size: + # reached the end, signal wrap around + self._ptr = 0 + self._write_ptr = 0 + self._wrap_event.write(1) + + return segment + + async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: + ''' + Receive up to `max_bytes`, if no `max_bytes` is provided + a reasonable default is used. + + Can return < max_bytes. + + ''' + if self.closed: + raise trio.ClosedResourceError + + if self._receive_some_lock.locked(): + raise trio.BusyResourceError + + async with self._receive_some_lock: + try: + # attempt direct read + return self.receive_nowait(max_bytes=max_bytes) + + except trio.WouldBlock as e: + # we have read all we can, see if new data is available + if not self.eof_was_signaled: + # if we havent been signaled about EOF yet + try: + # wait next write and advance `write_ptr` + delta = await self._write_event.read() + self._write_ptr += delta + # yield lock and re-enter + + except ( + EFDReadCancelled, # read was cancelled with cscope + trio.Cancelled, # read got cancelled from outside + trio.BrokenResourceError # OSError EBADF happened while reading + ): + # while waiting for new data `self._write_event` was closed + try: + # if eof was signaled receive no wait will not raise + # trio.WouldBlock and will push remaining until EOF + return self.receive_nowait(max_bytes=max_bytes) + + except trio.WouldBlock: + # eof was not signaled but `self._wrap_event` is closed + # this means send side closed without EOF signal + return b'' + + else: + # shouldnt happen because receive_nowait does not raise + # trio.WouldBlock when `end_ptr` is set + raise InternalError( + 'self._end_ptr is set but receive_nowait raised trio.WouldBlock' + ) from e + + return await self.receive_some(max_bytes=max_bytes) + + async def receive_exactly(self, num_bytes: int) -> bytes: + ''' + Fetch bytes until we read exactly `num_bytes` or EOC. 
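+
+        Raises `trio.EndOfChannel` if the stream ends before any bytes
+        arrive; a partial payload may be returned if EOF is hit
+        mid-read.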
+ + ''' + if self.closed: + raise trio.ClosedResourceError + + if self._receive_exactly_lock.locked(): + raise trio.BusyResourceError + + async with self._receive_exactly_lock: + payload = b'' + while len(payload) < num_bytes: + remaining = num_bytes - len(payload) + + new_bytes = await self.receive_some( + max_bytes=remaining + ) + + if new_bytes == b'': + break + + payload += new_bytes + + if payload == b'': + raise trio.EndOfChannel + + return payload + + async def receive(self, raw: bool = False) -> PayloadT: + ''' + Receive a complete payload or raise EOC + + ''' + if self.closed: + raise trio.ClosedResourceError + + if self._receive_lock.locked(): + raise trio.BusyResourceError + + async with self._receive_lock: + header: bytes = await self.receive_exactly(4) + size: int + size, = struct.unpack(" tuple[bytes, PayloadT]: + if not self._dec: + raise RuntimeError('iter_raw_pair requires decoder') + + while True: + try: + raw = await self.receive(raw=True) + yield raw, self._dec.decode(raw) + + except trio.EndOfChannel: + break + + def open(self): + try: + self._shm = SharedMemory( + name=self._token.shm_name, + size=self._token.buf_size, + create=False + ) + self._write_event.open() + self._wrap_event.open() + self._eof_event.open() + self._is_closed = False + + except Exception as e: + e.add_note(f'while opening receiver for {self._token.as_msg()}') + raise e + + def close(self): + if self._cleanup: + self._write_event.close() + self._wrap_event.close() + self._eof_event.close() + self._shm.close() + + self._is_closed = True + + async def aclose(self): + if self.closed: + return + + self.close() + + async def __aenter__(self): + self.open() + return self + + +async def _maybe_obtain_shared_resources(token: RBToken): + token = RBToken.from_msg(token) + + # maybe token wasn't allocated by current actor + if token.owner != current_actor().name: + # use fdshare module to retrieve a copy of the FDs + fds = await request_fds_from( + token.owner, + token.shm_name + ) + write, wrap, eof = fds + # rebuild token using FDs copies + token = RBToken( + owner=token.owner, + shm_name=token.shm_name, + write_eventfd=write, + wrap_eventfd=wrap, + eof_eventfd=eof, + buf_size=token.buf_size + ) + + return token + +@acm +async def attach_to_ringbuf_receiver( + + token: RBToken, + cleanup: bool = True, + decoder: Decoder | None = None, + is_ipc: bool = True + +) -> AsyncContextManager[RingBufferReceiveChannel]: + ''' + Attach a RingBufferReceiveChannel from a previously opened + RBToken. + + Requires tractor runtime to be up in order to support opening a ringbuf + originally allocated by a different actor. + + Launches `receiver._eof_monitor_task` in a `trio.Nursery`. + ''' + if is_ipc: + token = await _maybe_obtain_shared_resources(token) + + async with ( + trio.open_nursery(strict_exception_groups=False) as n, + RingBufferReceiveChannel( + token, + cleanup=cleanup, + decoder=decoder + ) as receiver + ): + n.start_soon(receiver._eof_monitor_task) + yield receiver + + +@acm +async def attach_to_ringbuf_sender( + + token: RBToken, + batch_size: int = 1, + cleanup: bool = True, + encoder: Encoder | None = None, + is_ipc: bool = True + +) -> AsyncContextManager[RingBufferSendChannel]: + ''' + Attach a RingBufferSendChannel from a previously opened + RBToken. + + Requires tractor runtime to be up in order to support opening a ringbuf + originally allocated by a different actor. 
+ + ''' + if is_ipc: + token = await _maybe_obtain_shared_resources(token) + + async with RingBufferSendChannel( + token, + batch_size=batch_size, + cleanup=cleanup, + encoder=encoder + ) as sender: + yield sender + + +class RingBufferChannel(trio.abc.Channel[bytes]): + ''' + Combine `RingBufferSendChannel` and `RingBufferReceiveChannel` + in order to expose the bidirectional `trio.abc.Channel` API. + + ''' + def __init__( + self, + sender: RingBufferSendChannel, + receiver: RingBufferReceiveChannel + ): + self._sender = sender + self._receiver = receiver + + @property + def batch_size(self) -> int: + return self._sender.batch_size + + @batch_size.setter + def batch_size(self, value: int) -> None: + self._sender.batch_size = value + + @property + def pending_msgs(self) -> int: + return self._sender.pending_msgs + + async def send_all(self, value: bytes) -> None: + await self._sender.send_all(value) + + async def wait_send_all_might_not_block(self): + await self._sender.wait_send_all_might_not_block() + + async def flush( + self, + new_batch_size: int | None = None + ) -> None: + await self._sender.flush(new_batch_size=new_batch_size) + + async def send(self, value: bytes) -> None: + await self._sender.send(value) + + async def send_eof(self) -> None: + await self._sender.send_eof() + + def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: + return self._receiver.receive_nowait(max_bytes=max_bytes) + + async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: + return await self._receiver.receive_some(max_bytes=max_bytes) + + async def receive_exactly(self, num_bytes: int) -> bytes: + return await self._receiver.receive_exactly(num_bytes) + + async def receive(self) -> bytes: + return await self._receiver.receive() + + async def aclose(self): + await self._receiver.aclose() + await self._sender.aclose() + + +@acm +async def attach_to_ringbuf_channel( + token_in: RBToken, + token_out: RBToken, + batch_size: int = 1, + cleanup_in: bool = True, + cleanup_out: bool = True, + encoder: Encoder | None = None, + decoder: Decoder | None = None, + sender_ipc: bool = True, + receiver_ipc: bool = True +) -> AsyncContextManager[trio.StapledStream]: + ''' + Attach to two previously opened `RBToken`s and return a `RingBufferChannel` + + ''' + async with ( + attach_to_ringbuf_receiver( + token_in, + cleanup=cleanup_in, + decoder=decoder, + is_ipc=receiver_ipc + ) as receiver, + attach_to_ringbuf_sender( + token_out, + batch_size=batch_size, + cleanup=cleanup_out, + encoder=encoder, + is_ipc=sender_ipc + ) as sender, + ): + yield RingBufferChannel(sender, receiver) diff --git a/tractor/ipc/_ringbuf/_pubsub.py b/tractor/ipc/_ringbuf/_pubsub.py new file mode 100644 index 00000000..c85de9ca --- /dev/null +++ b/tractor/ipc/_ringbuf/_pubsub.py @@ -0,0 +1,834 @@ +# tractor: structured concurrent "actors". +# Copyright 2018-eternity Tyler Goodlet. + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
+''' +Ring buffer ipc publish-subscribe mechanism brokered by ringd +can dynamically add new outputs (publisher) or inputs (subscriber) +''' +from typing import ( + TypeVar, + Generic, + Callable, + Awaitable, + AsyncContextManager +) +from functools import partial +from contextlib import asynccontextmanager as acm +from dataclasses import dataclass + +import trio +import tractor + +from msgspec.msgpack import ( + Encoder, + Decoder +) + +from tractor.ipc._ringbuf import ( + RBToken, + PayloadT, + RingBufferSendChannel, + RingBufferReceiveChannel, + attach_to_ringbuf_sender, + attach_to_ringbuf_receiver +) + +from tractor.trionics import ( + order_send_channel, + order_receive_channel +) + +import tractor.linux._fdshare as fdshare + + +log = tractor.log.get_logger(__name__) + + +ChannelType = TypeVar('ChannelType') + + +@dataclass +class ChannelInfo: + token: RBToken + channel: ChannelType + cancel_scope: trio.CancelScope + teardown: trio.Event + + +class ChannelManager(Generic[ChannelType]): + ''' + Helper for managing channel resources and their handler tasks with + cancellation, add or remove channels dynamically! + + ''' + + def __init__( + self, + # nursery used to spawn channel handler tasks + n: trio.Nursery, + + # acm will be used for setup & teardown of channel resources + open_channel_acm: Callable[..., AsyncContextManager[ChannelType]], + + # long running bg task to handle channel + channel_task: Callable[..., Awaitable[None]] + ): + self._n = n + self._open_channel = open_channel_acm + self._channel_task = channel_task + + # signal when a new channel conects and we previously had none + self._connect_event = trio.Event() + + # store channel runtime variables + self._channels: list[ChannelInfo] = [] + + self._is_closed: bool = True + + @property + def closed(self) -> bool: + return self._is_closed + + @property + def channels(self) -> list[ChannelInfo]: + return self._channels + + async def _channel_handler_task( + self, + token: RBToken, + task_status=trio.TASK_STATUS_IGNORED, + **kwargs + ): + ''' + Open channel resources, add to internal data structures, signal channel + connect through trio.Event, and run `channel_task` with cancel scope, + and finally, maybe remove channel from internal data structures. + + Spawned by `add_channel` function, lock is held from begining of fn + until `task_status.started()` call. + + kwargs are proxied to `self._open_channel` acm. + ''' + async with self._open_channel( + token, + **kwargs + ) as chan: + cancel_scope = trio.CancelScope() + info = ChannelInfo( + token=token, + channel=chan, + cancel_scope=cancel_scope, + teardown=trio.Event() + ) + self._channels.append(info) + + if len(self) == 1: + self._connect_event.set() + + task_status.started() + + with cancel_scope: + await self._channel_task(info) + + self._maybe_destroy_channel(token.shm_name) + + def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: + ''' + Given a channel name maybe return its index and value from + internal _channels list. + + Only use after acquiring lock. + ''' + for entry in enumerate(self._channels): + i, info = entry + if info.token.shm_name == name: + return entry + + return None + + + def _maybe_destroy_channel(self, name: str): + ''' + If channel exists cancel its scope and remove from internal + _channels list. 
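+        No-op if no channel with that name is registered.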
+ + ''' + maybe_entry = self._find_channel(name) + if maybe_entry: + i, info = maybe_entry + info.cancel_scope.cancel() + info.teardown.set() + del self._channels[i] + + async def add_channel( + self, + token: RBToken, + **kwargs + ): + ''' + Add a new channel to be handled + + ''' + if self.closed: + raise trio.ClosedResourceError + + await self._n.start(partial( + self._channel_handler_task, + RBToken.from_msg(token), + **kwargs + )) + + async def remove_channel(self, name: str): + ''' + Remove a channel and stop its handling + + ''' + if self.closed: + raise trio.ClosedResourceError + + maybe_entry = self._find_channel(name) + if not maybe_entry: + # return + raise RuntimeError( + f'tried to remove channel {name} but if does not exist' + ) + + i, info = maybe_entry + self._maybe_destroy_channel(name) + + await info.teardown.wait() + + # if that was last channel reset connect event + if len(self) == 0: + self._connect_event = trio.Event() + + async def wait_for_channel(self): + ''' + Wait until at least one channel added + + ''' + if self.closed: + raise trio.ClosedResourceError + + await self._connect_event.wait() + self._connect_event = trio.Event() + + def __len__(self) -> int: + return len(self._channels) + + def __getitem__(self, name: str): + maybe_entry = self._find_channel(name) + if maybe_entry: + _, info = maybe_entry + return info + + raise KeyError(f'Channel {name} not found!') + + def open(self): + self._is_closed = False + + async def close(self) -> None: + if self.closed: + log.warning('tried to close ChannelManager but its already closed...') + return + + for info in self._channels: + if info.channel.closed: + continue + + await info.channel.aclose() + await self.remove_channel(info.token.shm_name) + + self._is_closed = True + + +''' +Ring buffer publisher & subscribe pattern mediated by `ringd` actor. + +''' + + +class RingBufferPublisher(trio.abc.SendChannel[PayloadT]): + ''' + Use ChannelManager to create a multi ringbuf round robin sender that can + dynamically add or remove more outputs. + + Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its + lifecycle. + + ''' + def __init__( + self, + n: trio.Nursery, + + # amount of msgs to each ring before switching turns + msgs_per_turn: int = 1, + + # global batch size for all channels + batch_size: int = 1, + + encoder: Encoder | None = None + ): + self._batch_size: int = batch_size + self.msgs_per_turn = msgs_per_turn + self._enc = encoder + + # helper to manage acms + long running tasks + self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]]( + n, + self._open_channel, + self._channel_task + ) + + # ensure no concurrent `.send()` calls + self._send_lock = trio.StrictFIFOLock() + + # index of channel to be used for next send + self._next_turn: int = 0 + # amount of messages sent this turn + self._turn_msgs: int = 0 + # have we closed this publisher? + # set to `False` on `.__aenter__()` + self._is_closed: bool = True + + @property + def closed(self) -> bool: + return self._is_closed + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int) -> None: + for info in self.channels: + info.channel.batch_size = value + + @property + def channels(self) -> list[ChannelInfo]: + return self._chanmngr.channels + + def _get_next_turn(self) -> int: + ''' + Maybe switch turn and reset self._turn_msgs or just increment it. 
+ Return current turn + ''' + if self._turn_msgs == self.msgs_per_turn: + self._turn_msgs = 0 + self._next_turn += 1 + + if self._next_turn >= len(self.channels): + self._next_turn = 0 + + else: + self._turn_msgs += 1 + + return self._next_turn + + def get_channel(self, name: str) -> ChannelInfo: + ''' + Get underlying ChannelInfo from name + + ''' + return self._chanmngr[name] + + async def add_channel( + self, + token: RBToken, + ): + await self._chanmngr.add_channel(token) + + async def remove_channel(self, name: str): + await self._chanmngr.remove_channel(name) + + @acm + async def _open_channel( + + self, + token: RBToken + + ) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]: + async with attach_to_ringbuf_sender( + token, + batch_size=self._batch_size, + encoder=self._enc + ) as ring: + yield ring + + async def _channel_task(self, info: ChannelInfo) -> None: + ''' + Wait forever until channel cancellation + + ''' + await trio.sleep_forever() + + async def send(self, msg: bytes): + ''' + If no output channels connected, wait until one, then fetch the next + channel based on turn. + + Needs to acquire `self._send_lock` to ensure no concurrent calls. + + ''' + if self.closed: + raise trio.ClosedResourceError + + if self._send_lock.locked(): + raise trio.BusyResourceError + + async with self._send_lock: + # wait at least one decoder connected + if len(self.channels) == 0: + await self._chanmngr.wait_for_channel() + + turn = self._get_next_turn() + + info = self.channels[turn] + await info.channel.send(msg) + + async def broadcast(self, msg: PayloadT): + ''' + Send a msg to all channels, if no channels connected, does nothing. + ''' + if self.closed: + raise trio.ClosedResourceError + + for info in self.channels: + await info.channel.send(msg) + + async def flush(self, new_batch_size: int | None = None): + for info in self.channels: + try: + await info.channel.flush(new_batch_size=new_batch_size) + + except trio.ClosedResourceError: + ... + + async def __aenter__(self): + self._is_closed = False + self._chanmngr.open() + return self + + async def aclose(self) -> None: + if self.closed: + log.warning('tried to close RingBufferPublisher but its already closed...') + return + + await self._chanmngr.close() + + self._is_closed = True + + +class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]): + ''' + Use ChannelManager to create a multi ringbuf receiver that can + dynamically add or remove more inputs and combine all into a single output. + + In order for `self.receive` messages to be returned in order, publisher + will send all payloads as `OrderedPayload` msgpack encoded msgs, this + allows our channel handler tasks to just stash the out of order payloads + inside `self._pending_payloads` and if a in order payload is available + signal through `self._new_payload_event`. + + On `self.receive` we wait until at least one channel is connected, then if + an in order payload is pending, we pop and return it, in case no in order + payload is available wait until next `self._new_payload_event.set()`. 
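+
+    A minimal usage sketch (mirroring `subscribe_range` from the test
+    suite):
+
+        sub = get_subscriber()
+        msg = await sub.receive()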
+
+
+class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]):
+    '''
+    Use `ChannelManager` to create a multi-ringbuf receiver that can
+    dynamically add or remove inputs, combining them all into a single
+    output.
+
+    Each channel handler task forwards received msgs, in arrival order,
+    into an internal memory channel which `self.receive` reads from. To
+    additionally guarantee publisher-side ordering, this channel can be
+    wrapped with `trionics.order_receive_channel` (see the
+    `guarantee_order` flag on `open_ringbuf_subscriber`), which unwraps
+    `OrderedPayload` msgs and stashes out-of-order payloads until the next
+    in-order one arrives.
+
+    '''
+    def __init__(
+        self,
+        n: trio.Nursery,
+
+        decoder: Decoder | None = None
+    ):
+        self._dec = decoder
+        self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]](
+            n,
+            self._open_channel,
+            self._channel_task
+        )
+
+        self._schan, self._rchan = trio.open_memory_channel(0)
+
+        self._is_closed: bool = True
+
+        self._receive_lock = trio.StrictFIFOLock()
+
+    @property
+    def closed(self) -> bool:
+        return self._is_closed
+
+    @property
+    def channels(self) -> list[ChannelInfo]:
+        return self._chanmngr.channels
+
+    def get_channel(self, name: str):
+        return self._chanmngr[name]
+
+    async def add_channel(
+        self,
+        token: RBToken
+    ):
+        await self._chanmngr.add_channel(token)
+
+    async def remove_channel(self, name: str):
+        await self._chanmngr.remove_channel(name)
+
+    @acm
+    async def _open_channel(
+
+        self,
+        token: RBToken
+
+    ) -> AsyncContextManager[RingBufferReceiveChannel[PayloadT]]:
+        async with attach_to_ringbuf_receiver(
+            token,
+            decoder=self._dec
+        ) as ring:
+            yield ring
+
+    async def _channel_task(self, info: ChannelInfo) -> None:
+        '''
+        Forward all msgs received on this channel into the internal memory
+        channel, breaking out on end-of-stream, resource closure, or an
+        `EFDReadCancelled` (raised when the channel gets removed while we
+        are mid-read).
+
+        '''
+        while True:
+            try:
+                msg = await info.channel.receive()
+                await self._schan.send(msg)
+
+            except tractor.linux.eventfd.EFDReadCancelled as e:
+                # channel was removed while we were doing a receive
+                log.exception(e)
+                break
+
+            except trio.EndOfChannel:
+                break
+
+            except trio.ClosedResourceError:
+                break
+
+    async def receive(self) -> PayloadT:
+        '''
+        Receive the next msg in arrival order
+
+        '''
+        if self.closed:
+            raise trio.ClosedResourceError
+
+        if self._receive_lock.locked():
+            raise trio.BusyResourceError
+
+        async with self._receive_lock:
+            return await self._rchan.receive()
+
+    async def __aenter__(self):
+        self._is_closed = False
+        self._chanmngr.open()
+        return self
+
+    async def aclose(self) -> None:
+        if self.closed:
+            return
+
+        await self._chanmngr.close()
+        await self._schan.aclose()
+        await self._rchan.aclose()
+
+        self._is_closed = True
+
+
+'''
+Actor module for managing publisher & subscriber channels remotely through
+`tractor.context` rpc
+
+'''
+
+@dataclass
+class PublisherEntry:
+    publisher: RingBufferPublisher | None = None
+    # NOTE: must be a per-instance event; a plain `trio.Event()` default
+    # would be evaluated once at class definition and shared across all
+    # topic entries (`field` imported from `dataclasses`)
+    is_set: trio.Event = field(default_factory=trio.Event)
+
+
+_publishers: dict[str, PublisherEntry] = {}
+
+
+def maybe_init_publisher(topic: str) -> PublisherEntry:
+    entry = _publishers.get(topic, None)
+    if not entry:
+        entry = PublisherEntry()
+        _publishers[topic] = entry
+
+    return entry
+
+
+def set_publisher(topic: str, pub: RingBufferPublisher):
+    global _publishers
+
+    entry = _publishers.get(topic, None)
+    if not entry:
+        entry = maybe_init_publisher(topic)
+
+    if entry.publisher:
+        raise RuntimeError(
+            f'publisher for topic {topic} already set on {tractor.current_actor()}'
+        )
+
+    entry.publisher = pub
+    entry.is_set.set()
+
+
+def get_publisher(topic: str = 'default') -> RingBufferPublisher:
+    entry = _publishers.get(topic, None)
+    if not entry or not entry.publisher:
+        raise RuntimeError(
+            f'{tractor.current_actor()} tried to get publisher '
+            'but it\'s not set'
+        )
+
+    return entry.publisher
+
+
+async def wait_publisher(topic: str) -> RingBufferPublisher:
+    entry = maybe_init_publisher(topic)
+    await entry.is_set.wait()
+    return entry.publisher
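To illustrate the registry above: inside a child actor the `open_ringbuf_publisher` acm (defined further below) registers the publisher by topic, after which sibling tasks and `tractor.context` endpoints can look it up. A hedged sketch, endpoint name illustrative:

    @tractor.context
    async def demo_pub_child(ctx: tractor.Context):
        async with open_ringbuf_publisher(topic='demo') as pub:
            # `set_module_var=True` (the default) already ran
            # `set_publisher('demo', pub)`, so a lookup returns it:
            assert get_publisher('demo') is pub
            await ctx.started()
            await trio.sleep_forever()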
+
+
+@tractor.context
+async def _add_pub_channel(
+    ctx: tractor.Context,
+    topic: str,
+    token: RBToken
+):
+    publisher = await wait_publisher(topic)
+    await publisher.add_channel(token)
+
+
+@tractor.context
+async def _remove_pub_channel(
+    ctx: tractor.Context,
+    topic: str,
+    ring_name: str
+):
+    publisher = await wait_publisher(topic)
+    maybe_token = fdshare.maybe_get_fds(ring_name)
+    if maybe_token:
+        await publisher.remove_channel(ring_name)
+
+
+@acm
+async def open_pub_channel_at(
+    actor_name: str,
+    token: RBToken,
+    topic: str = 'default',
+):
+    async with tractor.find_actor(actor_name) as portal:
+        await portal.run(_add_pub_channel, topic=topic, token=token)
+        try:
+            yield
+
+        except trio.Cancelled:
+            log.warning(
+                'open_pub_channel_at got cancelled!\n'
+                f'\tactor_name = {actor_name}\n'
+                f'\ttoken = {token}\n'
+            )
+            raise
+
+        await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name)
+
+
+@dataclass
+class SubscriberEntry:
+    subscriber: RingBufferSubscriber | None = None
+    # NOTE: per-instance event, see `PublisherEntry`
+    is_set: trio.Event = field(default_factory=trio.Event)
+
+
+_subscribers: dict[str, SubscriberEntry] = {}
+
+
+def maybe_init_subscriber(topic: str) -> SubscriberEntry:
+    entry = _subscribers.get(topic, None)
+    if not entry:
+        entry = SubscriberEntry()
+        _subscribers[topic] = entry
+
+    return entry
+
+
+def set_subscriber(topic: str, sub: RingBufferSubscriber):
+    global _subscribers
+
+    entry = _subscribers.get(topic, None)
+    if not entry:
+        entry = maybe_init_subscriber(topic)
+
+    if entry.subscriber:
+        raise RuntimeError(
+            f'subscriber for topic {topic} already set on {tractor.current_actor()}'
+        )
+
+    entry.subscriber = sub
+    entry.is_set.set()
+
+
+def get_subscriber(topic: str = 'default') -> RingBufferSubscriber:
+    entry = _subscribers.get(topic, None)
+    if not entry or not entry.subscriber:
+        raise RuntimeError(
+            f'{tractor.current_actor()} tried to get subscriber '
+            'but it\'s not set'
+        )
+
+    return entry.subscriber
+
+
+async def wait_subscriber(topic: str) -> RingBufferSubscriber:
+    entry = maybe_init_subscriber(topic)
+    await entry.is_set.wait()
+    return entry.subscriber
+
+
+@tractor.context
+async def _add_sub_channel(
+    ctx: tractor.Context,
+    topic: str,
+    token: RBToken
+):
+    subscriber = await wait_subscriber(topic)
+    await subscriber.add_channel(token)
+
+
+@tractor.context
+async def _remove_sub_channel(
+    ctx: tractor.Context,
+    topic: str,
+    ring_name: str
+):
+    subscriber = await wait_subscriber(topic)
+    maybe_token = fdshare.maybe_get_fds(ring_name)
+    if maybe_token:
+        await subscriber.remove_channel(ring_name)
+
+
+@acm
+async def open_sub_channel_at(
+    actor_name: str,
+    token: RBToken,
+    topic: str = 'default',
+):
+    async with tractor.find_actor(actor_name) as portal:
+        await portal.run(_add_sub_channel, topic=topic, token=token)
+        try:
+            yield
+
+        except trio.Cancelled:
+            log.warning(
+                'open_sub_channel_at got cancelled!\n'
+                f'\tactor_name = {actor_name}\n'
+                f'\ttoken = {token}\n'
+            )
+            raise
+
+        await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name)
+
+
+'''
+High level helpers to open publisher & subscriber
+
+'''
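Taken together, these acms let an orchestrating actor wire a freshly allocated ring between two peers. A hedged sketch (actor & ring names illustrative; assumes `open_ringbuf` from `tractor.ipc._ringbuf` is an async cm like its plural sibling):

    from tractor.ipc._ringbuf import open_ringbuf

    async def wire_ring():
        async with (
            open_ringbuf('demo-ring') as token,
            open_sub_channel_at('sub', token),
            open_pub_channel_at('pub', token),
        ):
            # ring is now registered on both ends; the channels are
            # removed again on exit
            await trio.sleep_forever()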
+@acm
+async def open_ringbuf_publisher(
+    # name to distinguish this publisher
+    topic: str = 'default',
+
+    # global batch size for channels
+    batch_size: int = 1,
+
+    # messages before changing output channel
+    msgs_per_turn: int = 1,
+
+    encoder: Encoder | None = None,
+
+    # ensure subscriber receives msgs in the same order publisher sent
+    # them; causes use of wrapped payloads which carry the original send
+    # index
+    guarantee_order: bool = False,
+
+    # on creation, register this publisher in the module's `_publishers`
+    # registry so remote actors can add & remove channels via the
+    # `tractor.context` helper endpoints above
+    set_module_var: bool = True
+
+) -> AsyncContextManager[RingBufferPublisher]:
+    '''
+    Open a new ringbuf publisher
+
+    '''
+    async with (
+        trio.open_nursery(strict_exception_groups=False) as n,
+        RingBufferPublisher(
+            n,
+            # previously accepted but silently dropped
+            msgs_per_turn=msgs_per_turn,
+            batch_size=batch_size,
+            encoder=encoder,
+        ) as publisher
+    ):
+        if guarantee_order:
+            order_send_channel(publisher)
+
+        if set_module_var:
+            set_publisher(topic, publisher)
+
+        yield publisher
+
+        n.cancel_scope.cancel()
+
+
+@acm
+async def open_ringbuf_subscriber(
+    # name to distinguish this subscriber
+    topic: str = 'default',
+
+    decoder: Decoder | None = None,
+
+    # expect indexed payloads and unwrap them in order
+    guarantee_order: bool = False,
+
+    # on creation, register this subscriber in the module's `_subscribers`
+    # registry so remote actors can add & remove channels via the
+    # `tractor.context` helper endpoints above
+    set_module_var: bool = True
+) -> AsyncContextManager[RingBufferSubscriber]:
+    '''
+    Open a new ringbuf subscriber
+
+    '''
+    async with (
+        trio.open_nursery(strict_exception_groups=False) as n,
+        RingBufferSubscriber(n, decoder=decoder) as subscriber
+    ):
+        # maybe monkey patch `.receive` to use indexed payloads
+        if guarantee_order:
+            order_receive_channel(subscriber)
+
+        # maybe set global module var for remote actor channel updates
+        if set_module_var:
+            set_subscriber(topic, subscriber)
+
+        yield subscriber
+
+        n.cancel_scope.cancel()
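A hedged consumer-side sketch of the `guarantee_order` flag in action (the topic is omitted so 'default' is used):

    async def ordered_consumer():
        async with open_ringbuf_subscriber(guarantee_order=True) as sub:
            # `.receive()` was monkey patched by `order_receive_channel`:
            # payloads come out in the exact index order they were sent,
            # regardless of which input ring they arrived on
            async for msg in sub:
                print(msg)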
diff --git a/tractor/ipc/_transport.py b/tractor/ipc/_transport.py
index 6bfa5f6a..eb8ff3c9 100644
--- a/tractor/ipc/_transport.py
+++ b/tractor/ipc/_transport.py
@@ -78,7 +78,7 @@ class MsgTransport(Protocol):
     # eventual msg definition/types?
     # - https://docs.python.org/3/library/typing.html#typing.Protocol
 
-    stream: trio.SocketStream
+    stream: trio.abc.Stream
     drained: list[MsgType]
 
     address_type: ClassVar[Type[Address]]
diff --git a/tractor/linux/__init__.py b/tractor/linux/__init__.py
new file mode 100644
index 00000000..33526d14
--- /dev/null
+++ b/tractor/linux/__init__.py
@@ -0,0 +1,15 @@
+# tractor: structured concurrent "actors".
+# Copyright 2018-eternity Tyler Goodlet.
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
diff --git a/tractor/linux/_fdshare.py b/tractor/linux/_fdshare.py
new file mode 100644
index 00000000..84681455
--- /dev/null
+++ b/tractor/linux/_fdshare.py
@@ -0,0 +1,316 @@
+# tractor: structured concurrent "actors".
+# Copyright 2018-eternity Tyler Goodlet.
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+'''
+Reimplementation of `multiprocessing.reduction.sendfds` & `recvfds`, using
+acms and trio.
+
+cpython impl:
+https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
+
+'''
+import os
+import array
+import tempfile
+from uuid import uuid4
+from pathlib import Path
+from typing import AsyncContextManager
+from contextlib import asynccontextmanager as acm
+
+import trio
+import tractor
+from trio import socket
+
+
+log = tractor.log.get_logger(__name__)
+
+
+class FDSharingError(Exception):
+    ...
+
+
+@acm
+async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
+    '''
+    Async trio reimplementation of `multiprocessing.reduction.sendfds`
+
+    https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142
+
+    It's implemented as an async context manager in order to simplify usage
+    with `tractor.context`s: we can open a context in a remote actor that
+    uses this acm inside of it, and have it call `ctx.started()` to signal
+    the original caller actor to perform the `recv_fds` call.
+
+    See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
+
+    '''
+    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+    await sock.bind(sock_path)
+    sock.listen(1)
+
+    yield  # socket is setup, ready for receiver connect
+
+    # wait until receiver connects
+    conn, _ = await sock.accept()
+
+    # setup int array for fds
+    fds = array.array('i', fds)
+
+    # first byte of msg will be len(fds) % 256; it acts as an fd-amount
+    # verification on the `recv_fds` side, where we refer to it as the
+    # `check_byte`
+    msg = bytes([len(fds) % 256])
+
+    # send msg with custom SCM_RIGHTS type
+    await conn.sendmsg(
+        [msg],
+        [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
+    )
+
+    # finally wait for receiver ack
+    if await conn.recv(1) != b'A':
+        raise FDSharingError('did not receive acknowledgement of fd')
+
+    conn.close()
+    sock.close()
+    os.unlink(sock_path)
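To see how the two halves pair up, a hedged sketch running both in a single process purely for illustration (the sock path is illustrative); note `send_fds` only accepts & sends on exit of its body, so the receiver must run in a separate task:

    import trio

    async def demo_pass_fds(fds: list[int]):
        sock_path = '/tmp/demo-fds.sock'

        async def receiver():
            copies = await recv_fds(sock_path, len(fds))
            print(f'received fd copies: {copies}')

        async with trio.open_nursery() as n:
            async with send_fds(fds, sock_path):
                # socket is bound & listening, safe to connect now
                n.start_soon(receiver)
            # on exiting the acm body: accept, sendmsg, await ack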
+
+
+async def recv_fds(sock_path: str, amount: int) -> tuple:
+    '''
+    Async trio reimplementation of `multiprocessing.reduction.recvfds`
+
+    https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150
+
+    Equivalent to the stdlib impl, but uses `trio.open_unix_socket` for
+    connecting and differs in error handling.
+
+    See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
+
+    '''
+    stream = await trio.open_unix_socket(sock_path)
+    sock = stream.socket
+
+    # prepare int array for fds
+    a = array.array('i')
+    bytes_size = a.itemsize * amount
+
+    # receive 1 byte + the space necessary for an SCM_RIGHTS msg carrying
+    # {amount} fds
+    msg, ancdata, flags, addr = await sock.recvmsg(
+        1, socket.CMSG_SPACE(bytes_size)
+    )
+
+    # maybe failed to receive msg?
+    if not msg and not ancdata:
+        raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF')
+
+    # send ack; the stdlib comments mention this ack pattern was to get
+    # around an old macOS bug, and that they are not sure it's necessary
+    # any more; in any case it's not a bad pattern to keep
+    await sock.send(b'A')  # Ack
+
+    # expect to receive only one `ancdata` item
+    if len(ancdata) != 1:
+        raise FDSharingError(
+            f'Expected to receive exactly one "ancdata" but got {len(ancdata)}: {ancdata}'
+        )
+
+    # unpack SCM_RIGHTS msg
+    cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+
+    # check proper msg type
+    if cmsg_level != socket.SOL_SOCKET:
+        raise FDSharingError(
+            f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}'
+        )
+
+    if cmsg_type != socket.SCM_RIGHTS:
+        raise FDSharingError(
+            f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}'
+        )
+
+    # check proper data alignment
+    length = len(cmsg_data)
+    if length % a.itemsize != 0:
+        raise FDSharingError(
+            f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}'
+        )
+
+    # attempt to cast as int array
+    a.frombytes(cmsg_data)
+
+    # validate the length check byte
+    valid_check_byte = amount % 256  # expected check byte given the requested fd amount
+    recvd_check_byte = msg[0]  # check byte actually received
+    payload_check_byte = len(a) % 256  # check byte according to the received fd int array
+
+    if recvd_check_byte != payload_check_byte:
+        raise FDSharingError(
+            'Validation failed: received check byte '
+            f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})'
+        )
+
+    if valid_check_byte != recvd_check_byte:
+        raise FDSharingError(
+            'Validation failed: received check byte '
+            f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})'
+        )
+
+    return tuple(a)
+
+
+'''
+Share FD actor module.
+
+Add "tractor.linux._fdshare" to the enabled modules on actors to allow
+sharing of FDs with other actors.
+
+Use the `share_fds` function to register a set of fds with a name, then
+other actors can use the `request_fds_from` function to retrieve the fds.
+
+Use `unshare_fds` to disable sharing of a set of FDs.
+
+'''
+
+FDType = tuple[int, ...]
+
+_fds: dict[str, FDType] = {}
+
+
+def maybe_get_fds(name: str) -> FDType | None:
+    '''
+    Get registered FDs with a given name or return None
+
+    '''
+    return _fds.get(name, None)
+
+
+def get_fds(name: str) -> FDType:
+    '''
+    Get registered FDs with a given name or raise
+
+    '''
+    fds = maybe_get_fds(name)
+
+    if not fds:
+        raise RuntimeError(f'No FDs with name {name} found!')
+
+    return fds
+
+
+def share_fds(
+    name: str,
+    fds: FDType,
+) -> None:
+    '''
+    Register a set of fds to be shared under a given name.
+
+    '''
+    this_actor = tractor.current_actor()
+    if __name__ not in this_actor.enable_modules:
+        raise RuntimeError(
+            f'Tried to share FDs {fds} with name {name}, but '
+            f'module {__name__} is not enabled in actor {this_actor.name}!'
+        )
+
+    maybe_fds = maybe_get_fds(name)
+    if maybe_fds:
+        raise RuntimeError(f'share FDs: {maybe_fds} already tied to name {name}')
+
+    _fds[name] = fds
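A quick sketch of the registry semantics (must run inside an actor that enables this module; fd values illustrative):

    share_fds('demo', (10, 11, 12))
    assert maybe_get_fds('demo') == (10, 11, 12)

    # re-registering the same name raises
    try:
        share_fds('demo', (13,))
    except RuntimeError:
        ...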
+
+
+def unshare_fds(name: str) -> None:
+    '''
+    Unregister a set of fds to stop sharing them.
+
+    '''
+    get_fds(name)  # raise if it does not exist
+
+    del _fds[name]
+
+
+@tractor.context
+async def _pass_fds(
+    ctx: tractor.Context,
+    name: str,
+    sock_path: str
+) -> None:
+    '''
+    Endpoint to request a set of FDs from the current actor; uses
+    `ctx.started` to send the original FDs, then `send_fds` blocks until
+    the remote side finishes its `recv_fds` call.
+
+    '''
+    # get fds or raise error
+    fds = get_fds(name)
+
+    # start fd passing context using socket on `sock_path`
+    async with send_fds(fds, sock_path):
+        # send original fds through ctx.started
+        await ctx.started(fds)
+
+
+async def request_fds_from(
+    actor_name: str,
+    fds_name: str
+) -> FDType:
+    '''
+    Use this function to retrieve shared FDs from `actor_name`.
+
+    '''
+    this_actor = tractor.current_actor()
+
+    # create a temporary path for the UDS sock
+    sock_path = str(
+        Path(tempfile.gettempdir())
+        /
+        f'{fds_name}-from-{actor_name}-to-{this_actor.name}.sock'
+    )
+
+    # a socket path longer than ~100 chars can cause:
+    # OSError: AF_UNIX path too long
+    # https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/sys_un.h.html#tag_13_67_04
+
+    # so retry sock path creation with progressively shorter names
+    if len(sock_path) > 100:
+        sock_path = str(
+            Path(tempfile.gettempdir())
+            /
+            f'{fds_name}-to-{this_actor.name}.sock'
+        )
+
+    if len(sock_path) > 100:
+        # just use uuid4
+        sock_path = str(
+            Path(tempfile.gettempdir())
+            /
+            f'pass-fds-{uuid4()}.sock'
+        )
+
+    async with (
+        tractor.find_actor(actor_name) as portal,
+
+        portal.open_context(
+            _pass_fds,
+            name=fds_name,
+            sock_path=sock_path
+        ) as (ctx, fds_info),
+    ):
+        # get original FDs
+        og_fds = fds_info
+
+        # retrieve copies of the FDs
+        fds = await recv_fds(sock_path, len(og_fds))
+
+        log.info(
+            f'{this_actor.name} received fds: {og_fds} -> {fds}'
+        )
+
+        return fds
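End to end the flow is: the owning actor registers, any peer requests. A hedged sketch (actor & registry names illustrative):

    # in the actor that owns the fds (with this module enabled):
    #     share_fds('demo-fds', fds)

    # in any peer actor:
    async def fetch():
        fds = await request_fds_from('provider', 'demo-fds')
        # `fds` are this process' own duplicates of the originals
        return fds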
diff --git a/tractor/ipc/_linux.py b/tractor/linux/eventfd.py
similarity index 68%
rename from tractor/ipc/_linux.py
rename to tractor/linux/eventfd.py
index 88d80d1c..1b00a190 100644
--- a/tractor/ipc/_linux.py
+++ b/tractor/linux/eventfd.py
@@ -14,7 +14,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 '''
-Linux specifics, for now we are only exposing EventFD
+Expose libc eventfd APIs
 
 '''
 import os
@@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
         raise OSError(errno.errorcode[ffi.errno], 'close failed')
 
 
+class EFDReadCancelled(Exception):
+    ...
+
+
 class EventFD:
     '''
     Use a previously opened eventfd(2), meant to be used in
@@ -124,26 +128,82 @@ class EventFD:
         self._fd: int = fd
         self._omode: str = omode
         self._fobj = None
+        self._cscope: trio.CancelScope | None = None
+        self._is_closed: bool = True
+        self._read_lock = trio.StrictFIFOLock()
+
+    @property
+    def closed(self) -> bool:
+        return self._is_closed
 
     @property
     def fd(self) -> int | None:
         return self._fd
 
     def write(self, value: int) -> int:
+        if self.closed:
+            raise trio.ClosedResourceError
+
         return write_eventfd(self._fd, value)
 
     async def read(self) -> int:
-        return await trio.to_thread.run_sync(
-            read_eventfd, self._fd,
-            abandon_on_cancel=True
-        )
+        '''
+        Async wrapper for `read_eventfd(self.fd)`
+
+        Runs the blocking read through `trio.to_thread.run_sync` inside a
+        `trio.CancelScope` so it can be cancelled when `self.close()` is
+        called.
+
+        '''
+        if self.closed:
+            raise trio.ClosedResourceError
+
+        if self._read_lock.locked():
+            raise trio.BusyResourceError
+
+        async with self._read_lock:
+            self._cscope = trio.CancelScope()
+            with self._cscope:
+                try:
+                    res = await trio.to_thread.run_sync(
+                        read_eventfd, self._fd,
+                        abandon_on_cancel=True
+                    )
+
+                except OSError as e:
+                    if e.errno != errno.EBADF:
+                        raise
+
+                    raise trio.BrokenResourceError
+
+            if self._cscope.cancelled_caught:
+                raise EFDReadCancelled
+
+            # only reached when the read completed un-cancelled
+            self._cscope = None
+            return res
+
+    def read_nowait(self) -> int:
+        '''
+        Direct call to `read_eventfd(self.fd)`; unless the `eventfd` was
+        opened with `EFD_NONBLOCK` it's going to block the thread.
+
+        '''
+        return read_eventfd(self._fd)
 
     def open(self):
         self._fobj = os.fdopen(self._fd, self._omode)
+        self._is_closed = False
 
     def close(self):
         if self._fobj:
-            self._fobj.close()
+            try:
+                self._fobj.close()
+
+            except OSError:
+                ...
+
+        if self._cscope:
+            self._cscope.cancel()
+
+        self._is_closed = True
 
     def __enter__(self):
         self.open()
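For reference, a hedged sketch of the eventfd(2) counter semantics exposed by this wrapper (fd via `open_eventfd` from the same module): writes accumulate into the kernel counter and a read returns the sum, resetting it to zero.

    import trio
    from tractor.linux.eventfd import EventFD, open_eventfd

    async def counter_demo():
        with EventFD(open_eventfd(), 'w') as ev:
            ev.write(3)
            ev.write(4)
            # counter is non-zero so this returns immediately with 7
            assert await ev.read() == 7

    trio.run(counter_demo)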
diff --git a/tractor/trionics/__init__.py b/tractor/trionics/__init__.py
index df9b6f26..97d03da7 100644
--- a/tractor/trionics/__init__.py
+++ b/tractor/trionics/__init__.py
@@ -32,3 +32,8 @@ from ._broadcast import (
 from ._beg import (
     collapse_eg as collapse_eg,
 )
+
+from ._ordering import (
+    order_send_channel as order_send_channel,
+    order_receive_channel as order_receive_channel
+)
diff --git a/tractor/trionics/_mngrs.py b/tractor/trionics/_mngrs.py
index 9a5ed156..24b4fde8 100644
--- a/tractor/trionics/_mngrs.py
+++ b/tractor/trionics/_mngrs.py
@@ -70,7 +70,8 @@ async def maybe_open_nursery(
         yield nursery
     else:
         async with lib.open_nursery(**kwargs) as nursery:
-            nursery.cancel_scope.shield = shield
+            if lib == trio:
+                nursery.cancel_scope.shield = shield
             yield nursery
diff --git a/tractor/trionics/_ordering.py b/tractor/trionics/_ordering.py
new file mode 100644
index 00000000..0cc89b4b
--- /dev/null
+++ b/tractor/trionics/_ordering.py
@@ -0,0 +1,108 @@
+# tractor: structured concurrent "actors".
+# Copyright 2018-eternity Tyler Goodlet.
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+'''
+Helpers to guarantee ordering of messages through an unordered channel
+
+'''
+from __future__ import annotations
+from heapq import (
+    heappush,
+    heappop
+)
+
+import trio
+import msgspec
+
+
+class OrderedPayload(msgspec.Struct, frozen=True):
+    index: int
+    payload: bytes
+
+    @classmethod
+    def from_msg(cls, msg: bytes) -> OrderedPayload:
+        return msgspec.msgpack.decode(msg, type=OrderedPayload)
+
+    def encode(self) -> bytes:
+        return msgspec.msgpack.encode(self)
+
+
+def order_send_channel(
+    channel: trio.abc.SendChannel[bytes],
+    start_index: int = 0
+):
+
+    next_index = start_index
+    send_lock = trio.StrictFIFOLock()
+
+    # stash the originals so the wrappers below can delegate to them
+    channel._send = channel.send
+    channel._aclose = channel.aclose
+
+    async def send(msg: bytes):
+        nonlocal next_index
+        async with send_lock:
+            await channel._send(
+                OrderedPayload(
+                    index=next_index,
+                    payload=msg
+                ).encode()
+            )
+            next_index += 1
+
+    async def aclose():
+        async with send_lock:
+            await channel._aclose()
+
+    channel.send = send
+    channel.aclose = aclose
+
+
+def order_receive_channel(
+    channel: trio.abc.ReceiveChannel[bytes],
+    start_index: int = 0
+):
+    next_index = start_index
+    # min-heap of (index, payload) tuples holding early arrivals
+    pqueue = []
+
+    channel._receive = channel.receive
+
+    def can_pop_next() -> bool:
+        return (
+            len(pqueue) > 0
+            and
+            pqueue[0][0] == next_index
+        )
+
+    async def drain_to_heap():
+        while not can_pop_next():
+            msg = await channel._receive()
+            msg = OrderedPayload.from_msg(msg)
+            heappush(pqueue, (msg.index, msg.payload))
+
+    def pop_next():
+        nonlocal next_index
+        _, msg = heappop(pqueue)
+        next_index += 1
+        return msg
+
+    async def receive() -> bytes:
+        if can_pop_next():
+            return pop_next()
+
+        await drain_to_heap()
+
+        return pop_next()
+
+    channel.receive = receive
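To close out, a hedged, standalone sketch of the wire format and the heap-based reordering these helpers rely on (not part of the patch):

    from heapq import heappush, heappop

    from tractor.trionics._ordering import OrderedPayload

    # each msg is wrapped with its send index...
    wire = OrderedPayload(index=1, payload=b'world').encode()
    assert OrderedPayload.from_msg(wire).payload == b'world'

    # ...so the receive side can park early arrivals on a min-heap and
    # pop them back out in index order:
    pq = []
    for idx, payload in ((2, b'!'), (0, b'hello'), (1, b'world')):
        heappush(pq, (idx, payload))

    assert [heappop(pq)[1] for _ in range(3)] == [b'hello', b'world', b'!']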