diff --git a/tests/test_eventfd.py b/tests/test_eventfd.py
new file mode 100644
index 00000000..3432048b
--- /dev/null
+++ b/tests/test_eventfd.py
@@ -0,0 +1,66 @@
+import trio
+import pytest
+from tractor.linux.eventfd import (
+ open_eventfd,
+ EFDReadCancelled,
+ EventFD
+)
+
+
+def test_read_cancellation():
+ '''
+ Ensure EventFD.read raises EFDReadCancelled if EventFD.close()
+ is called.
+
+ '''
+ fd = open_eventfd()
+
+ async def bg_read(event: EventFD):
+ with pytest.raises(EFDReadCancelled):
+ await event.read()
+
+ async def main():
+ async with trio.open_nursery() as n:
+ with (
+ EventFD(fd, 'w') as event,
+ trio.fail_after(3)
+ ):
+ n.start_soon(bg_read, event)
+ await trio.sleep(0.2)
+ event.close()
+
+ trio.run(main)
+
+
+def test_read_trio_semantics():
+ '''
+    Ensure EventFD.read raises trio.BusyResourceError on a concurrent
+    read and trio.ClosedResourceError when reading after close.
+
+ '''
+
+ fd = open_eventfd()
+
+ async def bg_read(event: EventFD):
+ try:
+ await event.read()
+
+ except EFDReadCancelled:
+ ...
+
+ async def main():
+ async with trio.open_nursery() as n:
+
+ # start background read and attempt
+ # foreground read, should be busy
+ with EventFD(fd, 'w') as event:
+ n.start_soon(bg_read, event)
+ await trio.sleep(0.2)
+ with pytest.raises(trio.BusyResourceError):
+ await event.read()
+
+ # attempt read after close
+ with pytest.raises(trio.ClosedResourceError):
+ await event.read()
+
+ trio.run(main)
diff --git a/tests/test_ring_pubsub.py b/tests/test_ring_pubsub.py
new file mode 100644
index 00000000..b3b0dade
--- /dev/null
+++ b/tests/test_ring_pubsub.py
@@ -0,0 +1,185 @@
+from typing import AsyncContextManager
+from contextlib import asynccontextmanager as acm
+
+import trio
+import pytest
+import tractor
+
+from tractor.trionics import gather_contexts
+
+from tractor.ipc._ringbuf import open_ringbufs
+from tractor.ipc._ringbuf._pubsub import (
+ open_ringbuf_publisher,
+ open_ringbuf_subscriber,
+ get_publisher,
+ get_subscriber,
+ open_pub_channel_at,
+ open_sub_channel_at
+)
+
+
+log = tractor.log.get_console_log(level='info')
+
+
+@tractor.context
+async def publish_range(
+ ctx: tractor.Context,
+ size: int
+):
+ pub = get_publisher()
+ await ctx.started()
+ for i in range(size):
+ await pub.send(i.to_bytes(4))
+ log.info(f'sent {i}')
+
+ await pub.flush()
+
+ log.info('range done')
+
+
+@tractor.context
+async def subscribe_range(
+ ctx: tractor.Context,
+ size: int
+):
+ sub = get_subscriber()
+ await ctx.started()
+
+ for i in range(size):
+ recv = int.from_bytes(await sub.receive())
+ if recv != i:
+ raise AssertionError(
+ f'received: {recv} expected: {i}'
+ )
+
+ log.info(f'received: {recv}')
+
+ log.info('range done')
+
+
+@tractor.context
+async def subscriber_child(ctx: tractor.Context):
+ try:
+ async with open_ringbuf_subscriber(guarantee_order=True):
+ await ctx.started()
+ await trio.sleep_forever()
+
+ finally:
+ log.info('subscriber exit')
+
+
+@tractor.context
+async def publisher_child(
+ ctx: tractor.Context,
+ batch_size: int
+):
+ try:
+ async with open_ringbuf_publisher(
+ guarantee_order=True,
+ batch_size=batch_size
+ ):
+ await ctx.started()
+ await trio.sleep_forever()
+
+ finally:
+ log.info('publisher exit')
+
+
+@acm
+async def open_pubsub_test_actors(
+
+ ring_names: list[str],
+ size: int,
+ batch_size: int
+
+) -> AsyncContextManager[None]:
+
+ with trio.fail_after(5):
+ async with tractor.open_nursery(
+ enable_modules=[
+ 'tractor.linux._fdshare'
+ ]
+ ) as an:
+ modules = [
+ __name__,
+ 'tractor.linux._fdshare',
+ 'tractor.ipc._ringbuf._pubsub'
+ ]
+ sub_portal = await an.start_actor(
+ 'sub',
+ enable_modules=modules
+ )
+ pub_portal = await an.start_actor(
+ 'pub',
+ enable_modules=modules
+ )
+
+ async with (
+ sub_portal.open_context(subscriber_child) as (long_rctx, _),
+ pub_portal.open_context(
+ publisher_child,
+ batch_size=batch_size
+ ) as (long_sctx, _),
+
+ open_ringbufs(ring_names) as tokens,
+
+ gather_contexts([
+ open_sub_channel_at('sub', ring)
+ for ring in tokens
+ ]),
+ gather_contexts([
+ open_pub_channel_at('pub', ring)
+ for ring in tokens
+ ]),
+ sub_portal.open_context(subscribe_range, size=size) as (rctx, _),
+ pub_portal.open_context(publish_range, size=size) as (sctx, _)
+ ):
+ yield
+
+ await rctx.wait_for_result()
+ await sctx.wait_for_result()
+
+ await long_sctx.cancel()
+ await long_rctx.cancel()
+
+ await an.cancel()
+
+
+@pytest.mark.parametrize(
+ ('ring_names', 'size', 'batch_size'),
+ [
+ (
+ ['ring-first'],
+ 100,
+ 1
+ ),
+ (
+ ['ring-first'],
+ 69,
+ 1
+ ),
+ (
+ [f'multi-ring-{i}' for i in range(3)],
+ 1000,
+ 100
+ ),
+ ],
+ ids=[
+ 'simple',
+ 'redo-simple',
+ 'multi-ring',
+ ]
+)
+def test_pubsub(
+ request,
+ ring_names: list[str],
+ size: int,
+ batch_size: int
+):
+ async def main():
+ async with open_pubsub_test_actors(
+ ring_names, size, batch_size
+ ):
+ ...
+
+ trio.run(main)
diff --git a/tests/test_ringbuf.py b/tests/test_ringbuf.py
index 0d3b420b..1f6a1927 100644
--- a/tests/test_ringbuf.py
+++ b/tests/test_ringbuf.py
@@ -1,4 +1,5 @@
import time
+import hashlib
import trio
import pytest
@@ -6,36 +7,45 @@ import pytest
import tractor
from tractor.ipc._ringbuf import (
open_ringbuf,
+ open_ringbuf_pair,
+ attach_to_ringbuf_receiver,
+ attach_to_ringbuf_sender,
+ attach_to_ringbuf_channel,
RBToken,
- RingBuffSender,
- RingBuffReceiver
)
from tractor._testing.samples import (
- generate_sample_messages,
+ generate_single_byte_msgs,
+ RandomBytesGenerator
)
-# in case you don't want to melt your cores, uncomment dis!
-pytestmark = pytest.mark.skip
-
@tractor.context
async def child_read_shm(
ctx: tractor.Context,
- msg_amount: int,
token: RBToken,
- total_bytes: int,
-) -> None:
- recvd_bytes = 0
- await ctx.started()
- start_ts = time.time()
- async with RingBuffReceiver(token) as receiver:
- while recvd_bytes < total_bytes:
- msg = await receiver.receive_some()
- recvd_bytes += len(msg)
+) -> str:
+ '''
+ Sub-actor used in `test_ringbuf`.
- # make sure we dont hold any memoryviews
- # before the ctx manager aclose()
- msg = None
+ Attach to a ringbuf and receive all messages until end of stream.
+    Keep track of how many bytes were received and also calculate the
+    sha256 of the whole byte stream.
+
+    Calculate and print performance stats, finally return the
+    calculated hash.
+
+ '''
+ await ctx.started()
+ print('reader started')
+ msg_amount = 0
+ recvd_bytes = 0
+ recvd_hash = hashlib.sha256()
+ start_ts = time.time()
+ async with attach_to_ringbuf_receiver(token) as receiver:
+ async for msg in receiver:
+ msg_amount += 1
+ recvd_hash.update(msg)
+ recvd_bytes += len(msg)
end_ts = time.time()
elapsed = end_ts - start_ts
@@ -44,6 +54,10 @@ async def child_read_shm(
print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
+ print(f'\treceived msgs: {msg_amount:,}')
+ print(f'\treceived bytes: {recvd_bytes:,}')
+
+ return recvd_hash.hexdigest()
@tractor.context
@@ -52,17 +66,37 @@ async def child_write_shm(
msg_amount: int,
rand_min: int,
rand_max: int,
- token: RBToken,
+ buf_size: int
) -> None:
- msgs, total_bytes = generate_sample_messages(
+ '''
+    Sub-actor used in `test_ringbuf`.
+
+    Generate `msg_amount` payloads, each with
+    `random.randint(rand_min, rand_max)` random bytes appended.
+
+    Open a new ringbuf and send its token to the parent on
+    `ctx.started`, then attach and send all generated messages,
+    finally returning the sha256 hash of the whole byte stream.
+
+ '''
+ rng = RandomBytesGenerator(
msg_amount,
rand_min=rand_min,
rand_max=rand_max,
)
- await ctx.started(total_bytes)
- async with RingBuffSender(token) as sender:
- for msg in msgs:
- await sender.send_all(msg)
+ async with (
+ open_ringbuf('test_ringbuf', buf_size=buf_size) as token,
+ attach_to_ringbuf_sender(token) as sender
+ ):
+ await ctx.started(token)
+ print('writer started')
+ for msg in rng:
+ await sender.send(msg)
+
+ if rng.msgs_generated % rng.recommended_log_interval == 0:
+ print(f'wrote {rng.msgs_generated} msgs')
+
+ print('writer exit')
+ return rng.hexdigest
@pytest.mark.parametrize(
@@ -89,84 +123,91 @@ def test_ringbuf(
rand_max: int,
buf_size: int
):
+ '''
+    - Open `child_write_shm` ctx in a sub-actor which will open a new
+      ringbuf, send its token on `ctx.started`, stream a random payload
+      through it and finally return the payload's sha256 hash.
+    - Open `child_read_shm` ctx in a sub-actor which will receive the
+      payload, print perf stats and return the hash it calculated.
+    - Compare both hashes.
+
+ '''
async def main():
- with open_ringbuf(
- 'test_ringbuf',
- buf_size=buf_size
- ) as token:
- proc_kwargs = {
- 'pass_fds': (token.write_eventfd, token.wrap_eventfd)
- }
+ async with tractor.open_nursery() as an:
+ send_p = await an.start_actor(
+ 'ring_sender',
+ enable_modules=[
+ __name__,
+ 'tractor.linux._fdshare'
+ ],
+ )
+ recv_p = await an.start_actor(
+ 'ring_receiver',
+ enable_modules=[
+ __name__,
+ 'tractor.linux._fdshare'
+ ],
+ )
+ async with (
+ send_p.open_context(
+ child_write_shm,
+ msg_amount=msg_amount,
+ rand_min=rand_min,
+ rand_max=rand_max,
+ buf_size=buf_size
+ ) as (sctx, token),
- common_kwargs = {
- 'msg_amount': msg_amount,
- 'token': token,
- }
- async with tractor.open_nursery() as an:
- send_p = await an.start_actor(
- 'ring_sender',
- enable_modules=[__name__],
- proc_kwargs=proc_kwargs
- )
- recv_p = await an.start_actor(
- 'ring_receiver',
- enable_modules=[__name__],
- proc_kwargs=proc_kwargs
- )
- async with (
- send_p.open_context(
- child_write_shm,
- rand_min=rand_min,
- rand_max=rand_max,
- **common_kwargs
- ) as (sctx, total_bytes),
- recv_p.open_context(
- child_read_shm,
- **common_kwargs,
- total_bytes=total_bytes,
- ) as (sctx, _sent),
- ):
- await recv_p.result()
+ recv_p.open_context(
+ child_read_shm,
+ token=token,
+ ) as (rctx, _),
+ ):
+ sent_hash = await sctx.result()
+ recvd_hash = await rctx.result()
- await send_p.cancel_actor()
- await recv_p.cancel_actor()
+ assert sent_hash == recvd_hash
+ await an.cancel()
trio.run(main)
@tractor.context
-async def child_blocked_receiver(
- ctx: tractor.Context,
- token: RBToken
-):
- async with RingBuffReceiver(token) as receiver:
- await ctx.started()
+async def child_blocked_receiver(ctx: tractor.Context):
+ async with (
+ open_ringbuf('test_ring_cancel_reader') as token,
+
+ attach_to_ringbuf_receiver(token) as receiver
+ ):
+ await ctx.started(token)
await receiver.receive_some()
-def test_ring_reader_cancel():
+def test_reader_cancel():
+ '''
+ Test that a receiver blocked on eventfd(2) read responds to
+ cancellation.
+
+ '''
async def main():
- with open_ringbuf('test_ring_cancel_reader') as token:
+ async with tractor.open_nursery() as an:
+ recv_p = await an.start_actor(
+ 'ring_blocked_receiver',
+ enable_modules=[
+ __name__,
+ 'tractor.linux._fdshare'
+ ],
+ )
async with (
- tractor.open_nursery() as an,
- RingBuffSender(token) as _sender,
+ recv_p.open_context(
+ child_blocked_receiver,
+ ) as (sctx, token),
+
+ attach_to_ringbuf_sender(token),
):
- recv_p = await an.start_actor(
- 'ring_blocked_receiver',
- enable_modules=[__name__],
- proc_kwargs={
- 'pass_fds': (token.write_eventfd, token.wrap_eventfd)
- }
- )
- async with (
- recv_p.open_context(
- child_blocked_receiver,
- token=token
- ) as (sctx, _sent),
- ):
- await trio.sleep(1)
- await an.cancel()
+ await trio.sleep(.1)
+ await an.cancel()
with pytest.raises(tractor._exceptions.ContextCancelled):
@@ -174,38 +215,166 @@ def test_ring_reader_cancel():
@tractor.context
-async def child_blocked_sender(
- ctx: tractor.Context,
- token: RBToken
-):
- async with RingBuffSender(token) as sender:
- await ctx.started()
+async def child_blocked_sender(ctx: tractor.Context):
+ async with (
+ open_ringbuf(
+ 'test_ring_cancel_sender',
+ buf_size=1
+ ) as token,
+
+ attach_to_ringbuf_sender(token) as sender
+ ):
+ await ctx.started(token)
await sender.send_all(b'this will wrap')
-def test_ring_sender_cancel():
+def test_sender_cancel():
+ '''
+ Test that a sender blocked on eventfd(2) read responds to
+ cancellation.
+
+ '''
async def main():
- with open_ringbuf(
- 'test_ring_cancel_sender',
- buf_size=1
- ) as token:
- async with tractor.open_nursery() as an:
- recv_p = await an.start_actor(
- 'ring_blocked_sender',
- enable_modules=[__name__],
- proc_kwargs={
- 'pass_fds': (token.write_eventfd, token.wrap_eventfd)
- }
- )
- async with (
- recv_p.open_context(
- child_blocked_sender,
- token=token
- ) as (sctx, _sent),
- ):
- await trio.sleep(1)
- await an.cancel()
+ async with tractor.open_nursery() as an:
+ recv_p = await an.start_actor(
+ 'ring_blocked_sender',
+ enable_modules=[
+ __name__,
+ 'tractor.linux._fdshare'
+ ],
+ )
+ async with (
+ recv_p.open_context(
+ child_blocked_sender,
+ ) as (sctx, token),
+
+ attach_to_ringbuf_receiver(token)
+ ):
+ await trio.sleep(.1)
+ await an.cancel()
with pytest.raises(tractor._exceptions.ContextCancelled):
trio.run(main)
+
+
+def test_receiver_max_bytes():
+ '''
+    Test that `RingBufferReceiveChannel.receive_some`'s optional
+    `max_bytes` argument works correctly: send a msg of size 100, then
+    force receives with `max_bytes == 1` until 100 single byte chunks
+    have arrived, finally compare the joined chunks with the original
+    message.
+
+ '''
+ msg = generate_single_byte_msgs(100)
+ msgs = []
+
+ rb_common = {
+ 'cleanup': False,
+ 'is_ipc': False
+ }
+
+ async def main():
+ async with (
+ open_ringbuf(
+ 'test_ringbuf_max_bytes',
+ buf_size=10,
+ is_ipc=False
+ ) as token,
+
+ trio.open_nursery() as n,
+
+ attach_to_ringbuf_sender(token, **rb_common) as sender,
+
+ attach_to_ringbuf_receiver(token, **rb_common) as receiver
+ ):
+ async def _send_and_close():
+ await sender.send_all(msg)
+ await sender.aclose()
+
+ n.start_soon(_send_and_close)
+ while len(msgs) < len(msg):
+ msg_part = await receiver.receive_some(max_bytes=1)
+ assert len(msg_part) == 1
+ msgs.append(msg_part)
+
+ trio.run(main)
+ assert msg == b''.join(msgs)
+
+
+@tractor.context
+async def child_channel_sender(
+ ctx: tractor.Context,
+ msg_amount_min: int,
+ msg_amount_max: int,
+ token_in: RBToken,
+ token_out: RBToken
+):
+ import random
+ rng = RandomBytesGenerator(
+ random.randint(msg_amount_min, msg_amount_max),
+ rand_min=256,
+ rand_max=1024,
+ )
+ async with attach_to_ringbuf_channel(
+ token_in,
+ token_out
+ ) as chan:
+ await ctx.started()
+ for msg in rng:
+ await chan.send(msg)
+
+ await chan.send(b'bye')
+ await chan.receive()
+ return rng.hexdigest
+
+
+def test_channel():
+
+ msg_amount_min = 100
+ msg_amount_max = 1000
+
+ mods = [
+ __name__,
+ 'tractor.linux._fdshare'
+ ]
+
+ async def main():
+ async with (
+ tractor.open_nursery(enable_modules=mods) as an,
+
+ open_ringbuf_pair(
+ 'test_ringbuf_transport'
+ ) as (send_token, recv_token),
+
+ attach_to_ringbuf_channel(send_token, recv_token) as chan,
+ ):
+ sender = await an.start_actor(
+ 'test_ringbuf_transport_sender',
+ enable_modules=mods,
+ )
+ async with (
+ sender.open_context(
+ child_channel_sender,
+ msg_amount_min=msg_amount_min,
+ msg_amount_max=msg_amount_max,
+ token_in=recv_token,
+ token_out=send_token
+ ) as (ctx, _),
+ ):
+ recvd_hash = hashlib.sha256()
+ async for msg in chan:
+ if msg == b'bye':
+ await chan.send(b'bye')
+ break
+
+ recvd_hash.update(msg)
+
+ sent_hash = await ctx.result()
+
+ assert recvd_hash.hexdigest() == sent_hash
+
+ await an.cancel()
+
+ trio.run(main)
diff --git a/tractor/_discovery.py b/tractor/_discovery.py
index fd3e4b1c..478538c6 100644
--- a/tractor/_discovery.py
+++ b/tractor/_discovery.py
@@ -121,9 +121,14 @@ def get_peer_by_name(
actor: Actor = current_actor()
server: IPCServer = actor.ipc_server
to_scan: dict[tuple, list[Channel]] = server._peers.copy()
- pchan: Channel|None = actor._parent_chan
- if pchan:
- to_scan[pchan.uid].append(pchan)
+
+    # TODO: is this ever needed? it creates a duplicate channel in
+    # `actor._peers` when multiple `find_actor` calls are made to the
+    # same actor from a single ctx, which causes actor exit to hang
+    # forever on `actor._no_more_peers.wait()` in `_runtime.async_main`
+ # pchan: Channel|None = actor._parent_chan
+ # if pchan:
+ # to_scan[pchan.uid].append(pchan)
for aid, chans in to_scan.items():
_, peer_name = aid
diff --git a/tractor/_testing/samples.py b/tractor/_testing/samples.py
index a87a22c4..fcf41dfa 100644
--- a/tractor/_testing/samples.py
+++ b/tractor/_testing/samples.py
@@ -1,35 +1,99 @@
-import os
-import random
+import hashlib
+import numpy as np
-def generate_sample_messages(
- amount: int,
- rand_min: int = 0,
- rand_max: int = 0,
- silent: bool = False
-) -> tuple[list[bytes], int]:
+def generate_single_byte_msgs(amount: int) -> bytes:
+ '''
+    Generate a `bytes` instance of length `amount` with repeating
+    ASCII digits 0..9.
- msgs = []
- size = 0
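+
+    For example (a quick sketch of the expected output):
+
+        >>> generate_single_byte_msgs(12)
+        b'012345678901'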
+ '''
+ # array [0, 1, 2, ..., amount-1], take mod 10 => [0..9], and map 0->'0'(48)
+ # up to 9->'9'(57).
+ arr = np.arange(amount, dtype=np.uint8) % 10
+ # move into ascii space
+ arr += 48
+ return arr.tobytes()
- if not silent:
- print(f'\ngenerating {amount} messages...')
- for i in range(amount):
- msg = f'[{i:08}]'.encode('utf-8')
+class RandomBytesGenerator:
+ '''
+ Generate bytes msgs for tests.
- if rand_max > 0:
- msg += os.urandom(
- random.randint(rand_min, rand_max))
+    Messages will have the following format:
- size += len(msg)
+ b'[{i:08}]' + random_bytes
- msgs.append(msg)
+    So for message index 25:
- if not silent and i and i % 10_000 == 0:
- print(f'{i} generated')
+ b'[00000025]' + random_bytes
- if not silent:
- print(f'done, {size:,} bytes in total')
+    Also generates a sha256 hash of all generated msgs.
- return msgs, size
+ '''
+
+ def __init__(
+ self,
+ amount: int,
+ rand_min: int = 0,
+ rand_max: int = 0
+ ):
+ if rand_max < rand_min:
+ raise ValueError('rand_max must be >= rand_min')
+
+ self._amount = amount
+ self._rand_min = rand_min
+ self._rand_max = rand_max
+ self._index = 0
+ self._hasher = hashlib.sha256()
+ self._total_bytes = 0
+
+ self._lengths = np.random.randint(
+ rand_min,
+ rand_max + 1,
+ size=amount,
+ dtype=np.int32
+ )
+
+ def __iter__(self):
+ return self
+
+ def __next__(self) -> bytes:
+ if self._index == self._amount:
+ raise StopIteration
+
+ header = f'[{self._index:08}]'.encode('utf-8')
+
+ length = int(self._lengths[self._index])
+ msg = header + np.random.bytes(length)
+
+ self._hasher.update(msg)
+ self._total_bytes += length
+ self._index += 1
+
+ return msg
+
+ @property
+ def hexdigest(self) -> str:
+ return self._hasher.hexdigest()
+
+ @property
+ def total_bytes(self) -> int:
+ return self._total_bytes
+
+ @property
+ def total_msgs(self) -> int:
+ return self._amount
+
+ @property
+ def msgs_generated(self) -> int:
+ return self._index
+
+ @property
+ def recommended_log_interval(self) -> int:
+ max_msg_size = 10 + self._rand_max
+
+ if max_msg_size <= 32 * 1024:
+ return 10_000
+
+ else:
+ return 1000
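+
+
+# A minimal usage sketch (illustrative only, not part of the test suite):
+#
+#   rng = RandomBytesGenerator(10, rand_min=1, rand_max=16)
+#   msgs = list(rng)        # 10 framed payloads with random sized tails
+#   digest = rng.hexdigest  # sha256 over every byte generated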
diff --git a/tractor/ipc/__init__.py b/tractor/ipc/__init__.py
index 2c6c3b5d..37c3c8ed 100644
--- a/tractor/ipc/__init__.py
+++ b/tractor/ipc/__init__.py
@@ -13,7 +13,6 @@
# You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
-
'''
A modular IPC layer supporting the power of cross-process SC!
diff --git a/tractor/ipc/_ringbuf.py b/tractor/ipc/_ringbuf.py
deleted file mode 100644
index 6337eea1..00000000
--- a/tractor/ipc/_ringbuf.py
+++ /dev/null
@@ -1,253 +0,0 @@
-# tractor: structured concurrent "actors".
-# Copyright 2018-eternity Tyler Goodlet.
-
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-'''
-IPC Reliable RingBuffer implementation
-
-'''
-from __future__ import annotations
-from contextlib import contextmanager as cm
-from multiprocessing.shared_memory import SharedMemory
-
-import trio
-from msgspec import (
- Struct,
- to_builtins
-)
-
-from ._linux import (
- EFD_NONBLOCK,
- open_eventfd,
- EventFD
-)
-from ._mp_bs import disable_mantracker
-
-
-disable_mantracker()
-
-
-class RBToken(Struct, frozen=True):
- '''
- RingBuffer token contains necesary info to open the two
- eventfds and the shared memory
-
- '''
- shm_name: str
- write_eventfd: int
- wrap_eventfd: int
- buf_size: int
-
- def as_msg(self):
- return to_builtins(self)
-
- @classmethod
- def from_msg(cls, msg: dict) -> RBToken:
- if isinstance(msg, RBToken):
- return msg
-
- return RBToken(**msg)
-
-
-@cm
-def open_ringbuf(
- shm_name: str,
- buf_size: int = 10 * 1024,
- write_efd_flags: int = 0,
- wrap_efd_flags: int = 0
-) -> RBToken:
- shm = SharedMemory(
- name=shm_name,
- size=buf_size,
- create=True
- )
- try:
- token = RBToken(
- shm_name=shm_name,
- write_eventfd=open_eventfd(flags=write_efd_flags),
- wrap_eventfd=open_eventfd(flags=wrap_efd_flags),
- buf_size=buf_size
- )
- yield token
-
- finally:
- shm.unlink()
-
-
-class RingBuffSender(trio.abc.SendStream):
- '''
- IPC Reliable Ring Buffer sender side implementation
-
- `eventfd(2)` is used for wrap around sync, and also to signal
- writes to the reader.
-
- '''
- def __init__(
- self,
- token: RBToken,
- start_ptr: int = 0,
- ):
- token = RBToken.from_msg(token)
- self._shm = SharedMemory(
- name=token.shm_name,
- size=token.buf_size,
- create=False
- )
- self._write_event = EventFD(token.write_eventfd, 'w')
- self._wrap_event = EventFD(token.wrap_eventfd, 'r')
- self._ptr = start_ptr
-
- @property
- def key(self) -> str:
- return self._shm.name
-
- @property
- def size(self) -> int:
- return self._shm.size
-
- @property
- def ptr(self) -> int:
- return self._ptr
-
- @property
- def write_fd(self) -> int:
- return self._write_event.fd
-
- @property
- def wrap_fd(self) -> int:
- return self._wrap_event.fd
-
- async def send_all(self, data: bytes | bytearray | memoryview):
- # while data is larger than the remaining buf
- target_ptr = self.ptr + len(data)
- while target_ptr > self.size:
- # write all bytes that fit
- remaining = self.size - self.ptr
- self._shm.buf[self.ptr:] = data[:remaining]
- # signal write and wait for reader wrap around
- self._write_event.write(remaining)
- await self._wrap_event.read()
-
- # wrap around and trim already written bytes
- self._ptr = 0
- data = data[remaining:]
- target_ptr = self._ptr + len(data)
-
- # remaining data fits on buffer
- self._shm.buf[self.ptr:target_ptr] = data
- self._write_event.write(len(data))
- self._ptr = target_ptr
-
- async def wait_send_all_might_not_block(self):
- raise NotImplementedError
-
- async def aclose(self):
- self._write_event.close()
- self._wrap_event.close()
- self._shm.close()
-
- async def __aenter__(self):
- self._write_event.open()
- self._wrap_event.open()
- return self
-
-
-class RingBuffReceiver(trio.abc.ReceiveStream):
- '''
- IPC Reliable Ring Buffer receiver side implementation
-
- `eventfd(2)` is used for wrap around sync, and also to signal
- writes to the reader.
-
- '''
- def __init__(
- self,
- token: RBToken,
- start_ptr: int = 0,
- flags: int = 0
- ):
- token = RBToken.from_msg(token)
- self._shm = SharedMemory(
- name=token.shm_name,
- size=token.buf_size,
- create=False
- )
- self._write_event = EventFD(token.write_eventfd, 'w')
- self._wrap_event = EventFD(token.wrap_eventfd, 'r')
- self._ptr = start_ptr
- self._flags = flags
-
- @property
- def key(self) -> str:
- return self._shm.name
-
- @property
- def size(self) -> int:
- return self._shm.size
-
- @property
- def ptr(self) -> int:
- return self._ptr
-
- @property
- def write_fd(self) -> int:
- return self._write_event.fd
-
- @property
- def wrap_fd(self) -> int:
- return self._wrap_event.fd
-
- async def receive_some(
- self,
- max_bytes: int | None = None,
- nb_timeout: float = 0.1
- ) -> memoryview:
- # if non blocking eventfd enabled, do polling
- # until next write, this allows signal handling
- if self._flags | EFD_NONBLOCK:
- delta = None
- while delta is None:
- try:
- delta = await self._write_event.read()
-
- except OSError as e:
- if e.errno == 'EAGAIN':
- continue
-
- raise e
-
- else:
- delta = await self._write_event.read()
-
- # fetch next segment and advance ptr
- next_ptr = self._ptr + delta
- segment = self._shm.buf[self._ptr:next_ptr]
- self._ptr = next_ptr
-
- if self.ptr == self.size:
- # reached the end, signal wrap around
- self._ptr = 0
- self._wrap_event.write(1)
-
- return segment
-
- async def aclose(self):
- self._write_event.close()
- self._wrap_event.close()
- self._shm.close()
-
- async def __aenter__(self):
- self._write_event.open()
- self._wrap_event.open()
- return self
diff --git a/tractor/ipc/_ringbuf/__init__.py b/tractor/ipc/_ringbuf/__init__.py
new file mode 100644
index 00000000..12943707
--- /dev/null
+++ b/tractor/ipc/_ringbuf/__init__.py
@@ -0,0 +1,1004 @@
+# tractor: structured concurrent "actors".
+# Copyright 2018-eternity Tyler Goodlet.
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+'''
+IPC Reliable RingBuffer implementation
+
+'''
+from __future__ import annotations
+import struct
+from typing import (
+ TypeVar,
+ ContextManager,
+ AsyncContextManager
+)
+from contextlib import (
+ contextmanager as cm,
+ asynccontextmanager as acm
+)
+from multiprocessing.shared_memory import SharedMemory
+
+import trio
+from msgspec import (
+ Struct,
+ to_builtins
+)
+from msgspec.msgpack import (
+ Encoder,
+ Decoder,
+)
+
+from tractor.log import get_logger
+from tractor._exceptions import (
+ InternalError
+)
+from tractor.ipc._mp_bs import disable_mantracker
+from tractor.linux._fdshare import (
+ share_fds,
+ unshare_fds,
+ request_fds_from
+)
+from tractor.linux.eventfd import (
+ open_eventfd,
+ EFDReadCancelled,
+ EventFD
+)
+from tractor._state import current_actor
+
+
+log = get_logger(__name__)
+
+
+disable_mantracker()
+
+_DEFAULT_RB_SIZE = 10 * 1024
+
+
+class RBToken(Struct, frozen=True):
+ '''
+    RingBuffer token contains the necessary info to open the resources of
+    a ringbuf, even in the case the ringbuf was not allocated by the
+    current actor.
+
+ '''
+ owner: str | None # if owner != `current_actor().name` we must use fdshare
+
+ shm_name: str
+
+ write_eventfd: int # used to signal writer ptr advance
+ wrap_eventfd: int # used to signal reader ready after wrap around
+ eof_eventfd: int # used to signal writer closed
+
+ buf_size: int # size in bytes of underlying shared memory buffer
+
+ def as_msg(self):
+ return to_builtins(self)
+
+ @classmethod
+ def from_msg(cls, msg: dict) -> RBToken:
+ if isinstance(msg, RBToken):
+ return msg
+
+ return RBToken(**msg)
+
+ @property
+ def fds(self) -> tuple[int, int, int]:
+ return (
+ self.write_eventfd,
+ self.wrap_eventfd,
+ self.eof_eventfd
+ )
+
+
+def alloc_ringbuf(
+ shm_name: str,
+ buf_size: int = _DEFAULT_RB_SIZE,
+ is_ipc: bool = True
+) -> tuple[SharedMemory, RBToken]:
+ '''
+ Allocate OS resources for a ringbuf.
+ '''
+ shm = SharedMemory(
+ name=shm_name,
+ size=buf_size,
+ create=True
+ )
+ token = RBToken(
+ owner=current_actor().name if is_ipc else None,
+ shm_name=shm_name,
+ write_eventfd=open_eventfd(),
+ wrap_eventfd=open_eventfd(),
+ eof_eventfd=open_eventfd(),
+ buf_size=buf_size
+ )
+
+ if is_ipc:
+ # register fds for sharing
+ share_fds(
+ shm_name,
+ token.fds,
+ )
+
+ return shm, token
+
+
+@cm
+def open_ringbuf_sync(
+ shm_name: str,
+ buf_size: int = _DEFAULT_RB_SIZE,
+ is_ipc: bool = True
+) -> ContextManager[RBToken]:
+ '''
+    Handle resources for a ringbuf (shm, eventfds): yield an `RBToken` to
+    be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver`;
+    post yield, maybe unshare fds and unlink the shared memory.
+
+ '''
+ shm: SharedMemory | None = None
+ token: RBToken | None = None
+ try:
+ shm, token = alloc_ringbuf(
+ shm_name,
+ buf_size=buf_size,
+ is_ipc=is_ipc
+ )
+ yield token
+
+ finally:
+ if token and is_ipc:
+ unshare_fds(shm_name)
+
+ if shm:
+ shm.unlink()
+
+
+@acm
+async def open_ringbuf(
+ shm_name: str,
+ buf_size: int = _DEFAULT_RB_SIZE,
+ is_ipc: bool = True
+) -> AsyncContextManager[RBToken]:
+ '''
+ Helper to use `open_ringbuf_sync` inside an async with block.
+
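+    A minimal same-process sketch (`is_ipc=False` skips the fd-share
+    machinery, mirroring `test_receiver_max_bytes`):
+
+        rb_kwargs = {'cleanup': False, 'is_ipc': False}
+
+        async with (
+            open_ringbuf('demo-ring', is_ipc=False) as token,
+            attach_to_ringbuf_sender(token, **rb_kwargs) as sender,
+            attach_to_ringbuf_receiver(token, **rb_kwargs) as receiver,
+        ):
+            await sender.send(b'hello')
+            assert await receiver.receive() == b'hello'
+            await sender.aclose()  # signal EOF so the bg monitor exits
+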
+ '''
+ with open_ringbuf_sync(
+ shm_name,
+ buf_size=buf_size,
+ is_ipc=is_ipc
+ ) as token:
+ yield token
+
+
+@cm
+def open_ringbufs_sync(
+ shm_names: list[str],
+    buf_sizes: int | list[int] = _DEFAULT_RB_SIZE,
+    is_ipc: bool = True
+) -> ContextManager[tuple[RBToken, ...]]:
+ '''
+ Handle resources for multiple ringbufs at once.
+
+ '''
+    # maybe convert single int into list
+    if isinstance(buf_sizes, int):
+        buf_sizes = [buf_sizes] * len(shm_names)
+
+ # ensure len(shm_names) == len(buf_sizes)
+ if (
+ isinstance(buf_sizes, list)
+ and
+ len(buf_sizes) != len(shm_names)
+ ):
+ raise ValueError(
+ 'Expected buf_size list to be same length as shm_names'
+ )
+
+ # allocate resources
+ rings: list[tuple[SharedMemory, RBToken]] = [
+ alloc_ringbuf(
+ shm_name,
+ buf_size=buf_size,
+ is_ipc=is_ipc
+ )
+        for shm_name, buf_size in zip(shm_names, buf_sizes)
+ ]
+
+ try:
+ yield tuple([token for _, token in rings])
+
+ finally:
+ # attempt fd unshare and shm unlink for each
+ for shm, token in rings:
+ if is_ipc:
+ try:
+ unshare_fds(token.shm_name)
+
+ except RuntimeError:
+ log.exception(f'while unsharing fds of {token}')
+
+ shm.unlink()
+
+
+@acm
+async def open_ringbufs(
+ shm_names: list[str],
+    buf_sizes: int | list[int] = _DEFAULT_RB_SIZE,
+    is_ipc: bool = True
+) -> AsyncContextManager[tuple[RBToken, ...]]:
+ '''
+ Helper to use `open_ringbufs_sync` inside an async with block.
+
+ '''
+ with open_ringbufs_sync(
+ shm_names,
+ buf_sizes=buf_sizes,
+ is_ipc=is_ipc
+ ) as tokens:
+ yield tokens
+
+
+@cm
+def open_ringbuf_pair_sync(
+ shm_name: str,
+ buf_size: int = _DEFAULT_RB_SIZE,
+ is_ipc: bool = True
+) -> ContextManager[tuple[RBToken, RBToken]]:
+ '''
+ Handle resources for a ringbuf pair to be used for
+ bidirectional messaging.
+
+ '''
+ with open_ringbufs_sync(
+ [
+ f'{shm_name}.send',
+ f'{shm_name}.recv'
+ ],
+ buf_sizes=buf_size,
+ is_ipc=is_ipc
+ ) as tokens:
+ yield tokens
+
+
+@acm
+async def open_ringbuf_pair(
+ shm_name: str,
+ buf_size: int = _DEFAULT_RB_SIZE,
+ is_ipc: bool = True
+) -> AsyncContextManager[tuple[RBToken, RBToken]]:
+ '''
+ Helper to use `open_ringbuf_pair_sync` inside an async with block.
+
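+    A sketch of duplex usage; each side attaches with the two tokens
+    swapped relative to its peer (as done in `test_channel`):
+
+        async with (
+            open_ringbuf_pair('demo') as (send_token, recv_token),
+            attach_to_ringbuf_channel(send_token, recv_token) as chan,
+        ):
+            # the peer attaches with `token_in=recv_token,
+            # token_out=send_token` to complete the duplex pair
+            await chan.send(b'ping')
+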
+ '''
+ with open_ringbuf_pair_sync(
+ shm_name,
+ buf_size=buf_size,
+ is_ipc=is_ipc
+ ) as tokens:
+ yield tokens
+
+
+Buffer = bytes | bytearray | memoryview
+
+
+'''
+IPC Reliable Ring Buffer
+
+`eventfd(2)` is used for wrap around sync, to signal writes to
+the reader and end of stream.
+
+In order to guarantee full messages are received, all bytes
+sent by `RingBufferSendChannel` are preceded with a 4 byte header
+which decodes into a uint32 indicating the actual size of the
+next full payload.
+
+'''
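+
+# For reference, the framing described above amounts to (a sketch):
+#
+#   header: bytes = struct.pack("<I", len(payload))  # 4 byte LE uint32
+#   wire_msg: bytes = header + payload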
+
+
+PayloadT = TypeVar('PayloadT')
+
+
+class RingBufferSendChannel(trio.abc.SendChannel[PayloadT]):
+ '''
+ Ring Buffer sender side implementation
+
+ Do not use directly! manage with `attach_to_ringbuf_sender`
+ after having opened a ringbuf context with `open_ringbuf`.
+
+ Optional batch mode:
+
+    If `batch_size` > 1, messages won't get sent immediately but will
+    be stored until `batch_size` messages are pending, then they will
+    all be sent at once.
+
+    `batch_size` can be changed dynamically, but always call `flush()`
+    right before doing so.
+
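+    A sketch of batch mode usage (assumes an already attached `sender`):
+
+        sender.batch_size = 4          # queue up to 4 msgs before writing
+        for msg in msgs:
+            await sender.send(msg)     # stored until 4 are pending
+        await sender.flush()           # force out any remainder
+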
+ '''
+ def __init__(
+ self,
+ token: RBToken,
+ batch_size: int = 1,
+ cleanup: bool = False,
+ encoder: Encoder | None = None
+ ):
+ self._token = RBToken.from_msg(token)
+ self.batch_size = batch_size
+
+ # ringbuf os resources
+ self._shm: SharedMemory | None = None
+ self._write_event = EventFD(self._token.write_eventfd, 'w')
+ self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
+ self._eof_event = EventFD(self._token.eof_eventfd, 'w')
+
+ # current write pointer
+ self._ptr: int = 0
+
+        # when `batch_size` > 1 store messages on `self._batch` and write
+        # them all at once when `len(self._batch) == batch_size`
+ self._batch: list[bytes] = []
+
+ # close shm & fds on exit?
+ self._cleanup: bool = cleanup
+
+ self._enc: Encoder | None = encoder
+
+ # have we closed this ringbuf?
+ # set to `False` on `.open()`
+ self._is_closed: bool = True
+
+ # ensure no concurrent `.send_all()` calls
+ self._send_all_lock = trio.StrictFIFOLock()
+
+ # ensure no concurrent `.send()` calls
+ self._send_lock = trio.StrictFIFOLock()
+
+ # ensure no concurrent `.flush()` calls
+ self._flush_lock = trio.StrictFIFOLock()
+
+ @property
+ def closed(self) -> bool:
+ return self._is_closed
+
+ @property
+ def name(self) -> str:
+ if not self._shm:
+ raise ValueError('shared memory not initialized yet!')
+ return self._shm.name
+
+ @property
+ def size(self) -> int:
+ return self._token.buf_size
+
+ @property
+ def ptr(self) -> int:
+ return self._ptr
+
+ @property
+ def write_fd(self) -> int:
+ return self._write_event.fd
+
+ @property
+ def wrap_fd(self) -> int:
+ return self._wrap_event.fd
+
+ @property
+ def pending_msgs(self) -> int:
+ return len(self._batch)
+
+ @property
+ def must_flush(self) -> bool:
+ return self.pending_msgs >= self.batch_size
+
+ async def _wait_wrap(self):
+ await self._wrap_event.read()
+
+ async def send_all(self, data: Buffer):
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ if self._send_all_lock.locked():
+ raise trio.BusyResourceError
+
+ async with self._send_all_lock:
+ # while data is larger than the remaining buf
+ target_ptr = self.ptr + len(data)
+ while target_ptr > self.size:
+ # write all bytes that fit
+ remaining = self.size - self.ptr
+ self._shm.buf[self.ptr:] = data[:remaining]
+ # signal write and wait for reader wrap around
+ self._write_event.write(remaining)
+ await self._wait_wrap()
+
+ # wrap around and trim already written bytes
+ self._ptr = 0
+ data = data[remaining:]
+ target_ptr = self._ptr + len(data)
+
+ # remaining data fits on buffer
+ self._shm.buf[self.ptr:target_ptr] = data
+ self._write_event.write(len(data))
+ self._ptr = target_ptr
+
+ async def wait_send_all_might_not_block(self):
+ return
+
+ async def flush(
+ self,
+ new_batch_size: int | None = None
+ ) -> None:
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ async with self._flush_lock:
+ for msg in self._batch:
+ await self.send_all(msg)
+
+ self._batch = []
+ if new_batch_size:
+ self.batch_size = new_batch_size
+
+ async def send(self, value: PayloadT) -> None:
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ if self._send_lock.locked():
+ raise trio.BusyResourceError
+
+ raw_value: bytes = (
+ value
+ if isinstance(value, bytes)
+ else
+ self._enc.encode(value)
+ )
+
+ async with self._send_lock:
+ msg: bytes = struct.pack(" 0:
+ await self.flush()
+
+ await self.send_all(msg)
+ return
+
+ self._batch.append(msg)
+ if self.must_flush:
+ await self.flush()
+
+ def open(self):
+ try:
+ self._shm = SharedMemory(
+ name=self._token.shm_name,
+ size=self._token.buf_size,
+ create=False
+ )
+ self._write_event.open()
+ self._wrap_event.open()
+ self._eof_event.open()
+ self._is_closed = False
+
+ except Exception as e:
+ e.add_note(f'while opening sender for {self._token.as_msg()}')
+ raise e
+
+ def _close(self):
+ self._eof_event.write(
+ self._ptr if self._ptr > 0 else self.size
+ )
+
+ if self._cleanup:
+ self._write_event.close()
+ self._wrap_event.close()
+ self._eof_event.close()
+ self._shm.close()
+
+ self._is_closed = True
+
+ async def aclose(self):
+ if self.closed:
+ return
+
+ self._close()
+
+ async def __aenter__(self):
+ self.open()
+ return self
+
+
+class RingBufferReceiveChannel(trio.abc.ReceiveChannel[PayloadT]):
+ '''
+ Ring Buffer receiver side implementation
+
+ Do not use directly! manage with `attach_to_ringbuf_receiver`
+ after having opened a ringbuf context with `open_ringbuf`.
+
+ '''
+ def __init__(
+ self,
+ token: RBToken,
+ cleanup: bool = True,
+ decoder: Decoder | None = None
+ ):
+ self._token = RBToken.from_msg(token)
+
+ # ringbuf os resources
+ self._shm: SharedMemory | None = None
+ self._write_event = EventFD(self._token.write_eventfd, 'w')
+ self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
+ self._eof_event = EventFD(self._token.eof_eventfd, 'r')
+
+ # current read ptr
+ self._ptr: int = 0
+
+ # current write_ptr (max bytes we can read from buf)
+ self._write_ptr: int = 0
+
+        # end ptr is used when EOF is signaled, it will contain the
+        # maximum readable position on buf
+ self._end_ptr: int = -1
+
+ # close shm & fds on exit?
+ self._cleanup: bool = cleanup
+
+ # have we closed this ringbuf?
+ # set to `False` on `.open()`
+ self._is_closed: bool = True
+
+ self._dec: Decoder | None = decoder
+
+ # ensure no concurrent `.receive_some()` calls
+ self._receive_some_lock = trio.StrictFIFOLock()
+
+ # ensure no concurrent `.receive_exactly()` calls
+ self._receive_exactly_lock = trio.StrictFIFOLock()
+
+ # ensure no concurrent `.receive()` calls
+ self._receive_lock = trio.StrictFIFOLock()
+
+ @property
+ def closed(self) -> bool:
+ return self._is_closed
+
+ @property
+ def name(self) -> str:
+ if not self._shm:
+ raise ValueError('shared memory not initialized yet!')
+ return self._shm.name
+
+ @property
+ def size(self) -> int:
+ return self._token.buf_size
+
+ @property
+ def ptr(self) -> int:
+ return self._ptr
+
+ @property
+ def write_fd(self) -> int:
+ return self._write_event.fd
+
+ @property
+ def wrap_fd(self) -> int:
+ return self._wrap_event.fd
+
+ @property
+ def eof_was_signaled(self) -> bool:
+ return self._end_ptr != -1
+
+ async def _eof_monitor_task(self):
+ '''
+        Long running EOF event monitor, automatically run in the bg by
+        the `attach_to_ringbuf_receiver` context manager. If the EOF
+        event is set, its value will be the end pointer (the highest
+        valid index to be read from buf). After setting `self._end_ptr`
+        we close the write event, which should cancel any blocked
+        `self._write_event.read()`s on it.
+
+ '''
+ try:
+ self._end_ptr = await self._eof_event.read()
+
+ except EFDReadCancelled:
+ ...
+
+ except trio.Cancelled:
+ ...
+
+ finally:
+ # closing write_event should trigger `EFDReadCancelled`
+ # on any pending read
+ self._write_event.close()
+
+ def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
+ '''
+ Try to receive any bytes we can without blocking or raise
+ `trio.WouldBlock`.
+
+ Returns b'' when no more bytes can be read (EOF signaled & read all).
+
+ '''
+ if max_bytes < 1:
+ raise ValueError("max_bytes must be >= 1")
+
+ # in case `end_ptr` is set that means eof was signaled.
+ # it will be >= `write_ptr`, use it for delta calc
+ highest_ptr = max(self._write_ptr, self._end_ptr)
+
+ delta = highest_ptr - self._ptr
+
+ # no more bytes to read
+ if delta == 0:
+ # if `end_ptr` is set that means we read all bytes before EOF
+ if self.eof_was_signaled:
+ return b''
+
+ # signal the need to wait on `write_event`
+ raise trio.WouldBlock
+
+ # dont overflow caller
+ delta = min(delta, max_bytes)
+
+ target_ptr = self._ptr + delta
+
+ # fetch next segment and advance ptr
+ segment = bytes(self._shm.buf[self._ptr:target_ptr])
+ self._ptr = target_ptr
+
+ if self._ptr == self.size:
+ # reached the end, signal wrap around
+ self._ptr = 0
+ self._write_ptr = 0
+ self._wrap_event.write(1)
+
+ return segment
+
+ async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
+ '''
+        Receive up to `max_bytes`; if no `max_bytes` is provided,
+        a reasonable default is used.
+
+ Can return < max_bytes.
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ if self._receive_some_lock.locked():
+ raise trio.BusyResourceError
+
+ async with self._receive_some_lock:
+ try:
+ # attempt direct read
+ return self.receive_nowait(max_bytes=max_bytes)
+
+ except trio.WouldBlock as e:
+ # we have read all we can, see if new data is available
+ if not self.eof_was_signaled:
+                    # if we haven't been signaled about EOF yet
+ try:
+ # wait next write and advance `write_ptr`
+ delta = await self._write_event.read()
+ self._write_ptr += delta
+ # yield lock and re-enter
+
+ except (
+ EFDReadCancelled, # read was cancelled with cscope
+ trio.Cancelled, # read got cancelled from outside
+ trio.BrokenResourceError # OSError EBADF happened while reading
+ ):
+ # while waiting for new data `self._write_event` was closed
+ try:
+                        # if eof was signaled, receive_nowait will not raise
+                        # trio.WouldBlock and will push remaining bytes until EOF
+ return self.receive_nowait(max_bytes=max_bytes)
+
+ except trio.WouldBlock:
+                        # eof was not signaled but `self._write_event` is
+                        # closed, this means the send side closed without
+                        # an EOF signal
+ return b''
+
+                else:
+                    # shouldn't happen because receive_nowait does not raise
+                    # trio.WouldBlock when `end_ptr` is set
+                    raise InternalError(
+                        'self._end_ptr is set but receive_nowait raised trio.WouldBlock'
+                    ) from e
+
+ return await self.receive_some(max_bytes=max_bytes)
+
+ async def receive_exactly(self, num_bytes: int) -> bytes:
+ '''
+ Fetch bytes until we read exactly `num_bytes` or EOC.
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ if self._receive_exactly_lock.locked():
+ raise trio.BusyResourceError
+
+ async with self._receive_exactly_lock:
+ payload = b''
+ while len(payload) < num_bytes:
+ remaining = num_bytes - len(payload)
+
+ new_bytes = await self.receive_some(
+ max_bytes=remaining
+ )
+
+ if new_bytes == b'':
+ break
+
+ payload += new_bytes
+
+ if payload == b'':
+ raise trio.EndOfChannel
+
+ return payload
+
+ async def receive(self, raw: bool = False) -> PayloadT:
+ '''
+ Receive a complete payload or raise EOC
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ if self._receive_lock.locked():
+ raise trio.BusyResourceError
+
+ async with self._receive_lock:
+ header: bytes = await self.receive_exactly(4)
+ size: int
+ size, = struct.unpack(" tuple[bytes, PayloadT]:
+ if not self._dec:
+ raise RuntimeError('iter_raw_pair requires decoder')
+
+ while True:
+ try:
+ raw = await self.receive(raw=True)
+ yield raw, self._dec.decode(raw)
+
+ except trio.EndOfChannel:
+ break
+
+ def open(self):
+ try:
+ self._shm = SharedMemory(
+ name=self._token.shm_name,
+ size=self._token.buf_size,
+ create=False
+ )
+ self._write_event.open()
+ self._wrap_event.open()
+ self._eof_event.open()
+ self._is_closed = False
+
+ except Exception as e:
+ e.add_note(f'while opening receiver for {self._token.as_msg()}')
+ raise e
+
+ def close(self):
+ if self._cleanup:
+ self._write_event.close()
+ self._wrap_event.close()
+ self._eof_event.close()
+ self._shm.close()
+
+ self._is_closed = True
+
+ async def aclose(self):
+ if self.closed:
+ return
+
+ self.close()
+
+ async def __aenter__(self):
+ self.open()
+ return self
+
+
+async def _maybe_obtain_shared_resources(token: RBToken):
+ token = RBToken.from_msg(token)
+
+ # maybe token wasn't allocated by current actor
+ if token.owner != current_actor().name:
+ # use fdshare module to retrieve a copy of the FDs
+ fds = await request_fds_from(
+ token.owner,
+ token.shm_name
+ )
+ write, wrap, eof = fds
+ # rebuild token using FDs copies
+ token = RBToken(
+ owner=token.owner,
+ shm_name=token.shm_name,
+ write_eventfd=write,
+ wrap_eventfd=wrap,
+ eof_eventfd=eof,
+ buf_size=token.buf_size
+ )
+
+ return token
+
+
+@acm
+async def attach_to_ringbuf_receiver(
+
+ token: RBToken,
+ cleanup: bool = True,
+ decoder: Decoder | None = None,
+ is_ipc: bool = True
+
+) -> AsyncContextManager[RingBufferReceiveChannel]:
+ '''
+    Attach a `RingBufferReceiveChannel` to a previously opened ringbuf
+    via its `RBToken`.
+
+    Requires the tractor runtime to be up in order to support opening a
+    ringbuf originally allocated by a different actor.
+
+    Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
+
+    '''
+ if is_ipc:
+ token = await _maybe_obtain_shared_resources(token)
+
+ async with (
+ trio.open_nursery(strict_exception_groups=False) as n,
+ RingBufferReceiveChannel(
+ token,
+ cleanup=cleanup,
+ decoder=decoder
+ ) as receiver
+ ):
+ n.start_soon(receiver._eof_monitor_task)
+ yield receiver
+
+
+@acm
+async def attach_to_ringbuf_sender(
+
+ token: RBToken,
+ batch_size: int = 1,
+ cleanup: bool = True,
+ encoder: Encoder | None = None,
+ is_ipc: bool = True
+
+) -> AsyncContextManager[RingBufferSendChannel]:
+ '''
+    Attach a `RingBufferSendChannel` to a previously opened ringbuf
+    via its `RBToken`.
+
+    Requires the tractor runtime to be up in order to support opening a
+    ringbuf originally allocated by a different actor.
+
+ '''
+ if is_ipc:
+ token = await _maybe_obtain_shared_resources(token)
+
+ async with RingBufferSendChannel(
+ token,
+ batch_size=batch_size,
+ cleanup=cleanup,
+ encoder=encoder
+ ) as sender:
+ yield sender
+
+
+class RingBufferChannel(trio.abc.Channel[bytes]):
+ '''
+ Combine `RingBufferSendChannel` and `RingBufferReceiveChannel`
+ in order to expose the bidirectional `trio.abc.Channel` API.
+
+ '''
+ def __init__(
+ self,
+ sender: RingBufferSendChannel,
+ receiver: RingBufferReceiveChannel
+ ):
+ self._sender = sender
+ self._receiver = receiver
+
+ @property
+ def batch_size(self) -> int:
+ return self._sender.batch_size
+
+ @batch_size.setter
+ def batch_size(self, value: int) -> None:
+ self._sender.batch_size = value
+
+ @property
+ def pending_msgs(self) -> int:
+ return self._sender.pending_msgs
+
+ async def send_all(self, value: bytes) -> None:
+ await self._sender.send_all(value)
+
+ async def wait_send_all_might_not_block(self):
+ await self._sender.wait_send_all_might_not_block()
+
+ async def flush(
+ self,
+ new_batch_size: int | None = None
+ ) -> None:
+ await self._sender.flush(new_batch_size=new_batch_size)
+
+ async def send(self, value: bytes) -> None:
+ await self._sender.send(value)
+
+ async def send_eof(self) -> None:
+ await self._sender.send_eof()
+
+ def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
+ return self._receiver.receive_nowait(max_bytes=max_bytes)
+
+ async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
+ return await self._receiver.receive_some(max_bytes=max_bytes)
+
+ async def receive_exactly(self, num_bytes: int) -> bytes:
+ return await self._receiver.receive_exactly(num_bytes)
+
+ async def receive(self) -> bytes:
+ return await self._receiver.receive()
+
+ async def aclose(self):
+ await self._receiver.aclose()
+ await self._sender.aclose()
+
+
+@acm
+async def attach_to_ringbuf_channel(
+ token_in: RBToken,
+ token_out: RBToken,
+ batch_size: int = 1,
+ cleanup_in: bool = True,
+ cleanup_out: bool = True,
+ encoder: Encoder | None = None,
+ decoder: Decoder | None = None,
+ sender_ipc: bool = True,
+ receiver_ipc: bool = True
+) -> AsyncContextManager[RingBufferChannel]:
+    '''
+    Attach to two previously opened `RBToken`s and yield a `RingBufferChannel`.
+
+ '''
+ async with (
+ attach_to_ringbuf_receiver(
+ token_in,
+ cleanup=cleanup_in,
+ decoder=decoder,
+ is_ipc=receiver_ipc
+ ) as receiver,
+ attach_to_ringbuf_sender(
+ token_out,
+ batch_size=batch_size,
+ cleanup=cleanup_out,
+ encoder=encoder,
+ is_ipc=sender_ipc
+ ) as sender,
+ ):
+ yield RingBufferChannel(sender, receiver)
diff --git a/tractor/ipc/_ringbuf/_pubsub.py b/tractor/ipc/_ringbuf/_pubsub.py
new file mode 100644
index 00000000..c85de9ca
--- /dev/null
+++ b/tractor/ipc/_ringbuf/_pubsub.py
@@ -0,0 +1,834 @@
+# tractor: structured concurrent "actors".
+# Copyright 2018-eternity Tyler Goodlet.
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+'''
+Ring buffer IPC publish-subscribe mechanism, brokered by ringd,
+which can dynamically add new outputs (publisher) or inputs (subscriber).
+
+'''
+from typing import (
+ TypeVar,
+ Generic,
+ Callable,
+ Awaitable,
+ AsyncContextManager
+)
+from functools import partial
+from contextlib import asynccontextmanager as acm
+from dataclasses import dataclass
+
+import trio
+import tractor
+
+from msgspec.msgpack import (
+ Encoder,
+ Decoder
+)
+
+from tractor.ipc._ringbuf import (
+ RBToken,
+ PayloadT,
+ RingBufferSendChannel,
+ RingBufferReceiveChannel,
+ attach_to_ringbuf_sender,
+ attach_to_ringbuf_receiver
+)
+
+from tractor.trionics import (
+ order_send_channel,
+ order_receive_channel
+)
+
+import tractor.linux._fdshare as fdshare
+
+
+log = tractor.log.get_logger(__name__)
+
+
+ChannelType = TypeVar('ChannelType')
+
+
+@dataclass
+class ChannelInfo:
+ token: RBToken
+ channel: ChannelType
+ cancel_scope: trio.CancelScope
+ teardown: trio.Event
+
+
+class ChannelManager(Generic[ChannelType]):
+ '''
+ Helper for managing channel resources and their handler tasks with
+ cancellation, add or remove channels dynamically!
+
+ '''
+
+ def __init__(
+ self,
+ # nursery used to spawn channel handler tasks
+ n: trio.Nursery,
+
+ # acm will be used for setup & teardown of channel resources
+ open_channel_acm: Callable[..., AsyncContextManager[ChannelType]],
+
+ # long running bg task to handle channel
+ channel_task: Callable[..., Awaitable[None]]
+ ):
+ self._n = n
+ self._open_channel = open_channel_acm
+ self._channel_task = channel_task
+
+        # signal when a new channel connects and we previously had none
+ self._connect_event = trio.Event()
+
+ # store channel runtime variables
+ self._channels: list[ChannelInfo] = []
+
+ self._is_closed: bool = True
+
+ @property
+ def closed(self) -> bool:
+ return self._is_closed
+
+ @property
+ def channels(self) -> list[ChannelInfo]:
+ return self._channels
+
+ async def _channel_handler_task(
+ self,
+ token: RBToken,
+ task_status=trio.TASK_STATUS_IGNORED,
+ **kwargs
+ ):
+ '''
+ Open channel resources, add to internal data structures, signal channel
+ connect through trio.Event, and run `channel_task` with cancel scope,
+ and finally, maybe remove channel from internal data structures.
+
+        Spawned by the `add_channel` function, which blocks until the
+        `task_status.started()` call.
+
+        kwargs are proxied to the `self._open_channel` acm.
+ '''
+ async with self._open_channel(
+ token,
+ **kwargs
+ ) as chan:
+ cancel_scope = trio.CancelScope()
+ info = ChannelInfo(
+ token=token,
+ channel=chan,
+ cancel_scope=cancel_scope,
+ teardown=trio.Event()
+ )
+ self._channels.append(info)
+
+ if len(self) == 1:
+ self._connect_event.set()
+
+ task_status.started()
+
+ with cancel_scope:
+ await self._channel_task(info)
+
+ self._maybe_destroy_channel(token.shm_name)
+
+ def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
+ '''
+ Given a channel name maybe return its index and value from
+ internal _channels list.
+
+ Only use after acquiring lock.
+ '''
+ for entry in enumerate(self._channels):
+ i, info = entry
+ if info.token.shm_name == name:
+ return entry
+
+ return None
+
+ def _maybe_destroy_channel(self, name: str):
+ '''
+ If channel exists cancel its scope and remove from internal
+ _channels list.
+
+ '''
+ maybe_entry = self._find_channel(name)
+ if maybe_entry:
+ i, info = maybe_entry
+ info.cancel_scope.cancel()
+ info.teardown.set()
+ del self._channels[i]
+
+ async def add_channel(
+ self,
+ token: RBToken,
+ **kwargs
+ ):
+ '''
+ Add a new channel to be handled
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ await self._n.start(partial(
+ self._channel_handler_task,
+ RBToken.from_msg(token),
+ **kwargs
+ ))
+
+ async def remove_channel(self, name: str):
+ '''
+ Remove a channel and stop its handling
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ maybe_entry = self._find_channel(name)
+ if not maybe_entry:
+            raise RuntimeError(
+                f'tried to remove channel {name} but it does not exist'
+ )
+
+ i, info = maybe_entry
+ self._maybe_destroy_channel(name)
+
+ await info.teardown.wait()
+
+        # if that was the last channel, reset the connect event
+ if len(self) == 0:
+ self._connect_event = trio.Event()
+
+ async def wait_for_channel(self):
+ '''
+        Wait until at least one channel has been added.
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ await self._connect_event.wait()
+ self._connect_event = trio.Event()
+
+ def __len__(self) -> int:
+ return len(self._channels)
+
+ def __getitem__(self, name: str):
+ maybe_entry = self._find_channel(name)
+ if maybe_entry:
+ _, info = maybe_entry
+ return info
+
+ raise KeyError(f'Channel {name} not found!')
+
+ def open(self):
+ self._is_closed = False
+
+ async def close(self) -> None:
+ if self.closed:
+            log.warning('tried to close ChannelManager but it\'s already closed...')
+ return
+
+        # iterate over a copy since `remove_channel` mutates `self._channels`
+        for info in self._channels[:]:
+ if info.channel.closed:
+ continue
+
+ await info.channel.aclose()
+ await self.remove_channel(info.token.shm_name)
+
+ self._is_closed = True
+
+
+'''
+Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
+
+'''
+
+
+class RingBufferPublisher(trio.abc.SendChannel[PayloadT]):
+ '''
+ Use ChannelManager to create a multi ringbuf round robin sender that can
+ dynamically add or remove more outputs.
+
+ Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its
+ lifecycle.
+
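+    A usage sketch (assumes a ringbuf `token` from `open_ringbufs`):
+
+        async with open_ringbuf_publisher(batch_size=4) as pub:
+            await pub.add_channel(token)
+            await pub.send(b'hello')        # round robin across channels
+            await pub.broadcast(b'to all')  # sent to every channel
+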
+ '''
+ def __init__(
+ self,
+ n: trio.Nursery,
+
+ # amount of msgs to each ring before switching turns
+ msgs_per_turn: int = 1,
+
+ # global batch size for all channels
+ batch_size: int = 1,
+
+ encoder: Encoder | None = None
+ ):
+ self._batch_size: int = batch_size
+ self.msgs_per_turn = msgs_per_turn
+ self._enc = encoder
+
+ # helper to manage acms + long running tasks
+ self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]](
+ n,
+ self._open_channel,
+ self._channel_task
+ )
+
+ # ensure no concurrent `.send()` calls
+ self._send_lock = trio.StrictFIFOLock()
+
+ # index of channel to be used for next send
+ self._next_turn: int = 0
+ # amount of messages sent this turn
+ self._turn_msgs: int = 0
+ # have we closed this publisher?
+ # set to `False` on `.__aenter__()`
+ self._is_closed: bool = True
+
+ @property
+ def closed(self) -> bool:
+ return self._is_closed
+
+ @property
+ def batch_size(self) -> int:
+ return self._batch_size
+
+ @batch_size.setter
+ def batch_size(self, value: int) -> None:
+        # keep our own copy in sync with all connected channels
+        self._batch_size = value
+        for info in self.channels:
+ info.channel.batch_size = value
+
+ @property
+ def channels(self) -> list[ChannelInfo]:
+ return self._chanmngr.channels
+
+ def _get_next_turn(self) -> int:
+ '''
+        Maybe switch turn and reset `self._turn_msgs`, or just increment it.
+        Return the current turn.
+ '''
+ if self._turn_msgs == self.msgs_per_turn:
+ self._turn_msgs = 0
+ self._next_turn += 1
+
+ if self._next_turn >= len(self.channels):
+ self._next_turn = 0
+
+ else:
+ self._turn_msgs += 1
+
+ return self._next_turn
+
+ def get_channel(self, name: str) -> ChannelInfo:
+ '''
+ Get underlying ChannelInfo from name
+
+ '''
+ return self._chanmngr[name]
+
+ async def add_channel(
+ self,
+ token: RBToken,
+ ):
+ await self._chanmngr.add_channel(token)
+
+ async def remove_channel(self, name: str):
+ await self._chanmngr.remove_channel(name)
+
+ @acm
+ async def _open_channel(
+
+ self,
+ token: RBToken
+
+ ) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]:
+ async with attach_to_ringbuf_sender(
+ token,
+ batch_size=self._batch_size,
+ encoder=self._enc
+ ) as ring:
+ yield ring
+
+ async def _channel_task(self, info: ChannelInfo) -> None:
+ '''
+ Wait forever until channel cancellation
+
+ '''
+ await trio.sleep_forever()
+
+    async def send(self, msg: PayloadT):
+ '''
+ If no output channels connected, wait until one, then fetch the next
+ channel based on turn.
+
+ Needs to acquire `self._send_lock` to ensure no concurrent calls.
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ if self._send_lock.locked():
+ raise trio.BusyResourceError
+
+ async with self._send_lock:
+            # wait until at least one channel is connected
+ if len(self.channels) == 0:
+ await self._chanmngr.wait_for_channel()
+
+ turn = self._get_next_turn()
+
+ info = self.channels[turn]
+ await info.channel.send(msg)
+
+ async def broadcast(self, msg: PayloadT):
+ '''
+ Send a msg to all connected channels; if none are connected this
+ is a no-op.
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ for info in self.channels:
+ await info.channel.send(msg)
+
+ async def flush(self, new_batch_size: int | None = None):
+ for info in self.channels:
+ try:
+ await info.channel.flush(new_batch_size=new_batch_size)
+
+ except trio.ClosedResourceError:
+ ...
+
+ async def __aenter__(self):
+ self._is_closed = False
+ self._chanmngr.open()
+ return self
+
+ async def aclose(self) -> None:
+ if self.closed:
+ log.warning('tried to close RingBufferPublisher but it\'s already closed...')
+ return
+
+ await self._chanmngr.close()
+
+ self._is_closed = True
+
+
+class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]):
+ '''
+ Use `ChannelManager` to create a multi-ringbuf receiver that can
+ dynamically add or remove inputs and combine them all into a single
+ output.
+
+ Each channel handler task forwards received payloads into an
+ internal `trio` memory channel which `self.receive` pops from. When
+ ordered delivery is required the subscriber gets wrapped with
+ `order_receive_channel` (see the `guarantee_order` flag on
+ `open_ringbuf_subscriber`): the publisher then sends indexed
+ `OrderedPayload` msgs which are buffered and yielded strictly in
+ order.
+
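+ A minimal usage sketch (hedged: assumes an `RBToken` `tok` obtained
+ elsewhere):
+
+ async with open_ringbuf_subscriber() as sub:
+ await sub.add_channel(tok)
+ msg = await sub.receive()
+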
+ '''
+ def __init__(
+ self,
+ n: trio.Nursery,
+
+ decoder: Decoder | None = None
+ ):
+ self._dec = decoder
+ self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]](
+ n,
+ self._open_channel,
+ self._channel_task
+ )
+
+ self._schan, self._rchan = trio.open_memory_channel(0)
+
+ self._is_closed: bool = True
+
+ self._receive_lock = trio.StrictFIFOLock()
+
+ @property
+ def closed(self) -> bool:
+ return self._is_closed
+
+ @property
+ def channels(self) -> list[ChannelInfo]:
+ return self._chanmngr.channels
+
+ def get_channel(self, name: str):
+ return self._chanmngr[name]
+
+ async def add_channel(
+ self,
+ token: RBToken
+ ):
+ await self._chanmngr.add_channel(token)
+
+ async def remove_channel(self, name: str):
+ await self._chanmngr.remove_channel(name)
+
+ @acm
+ async def _open_channel(
+
+ self,
+ token: RBToken
+
+ ) -> AsyncContextManager[RingBufferReceiveChannel[PayloadT]]:
+ async with attach_to_ringbuf_receiver(
+ token,
+ decoder=self._dec
+ ) as ring:
+ yield ring
+
+ async def _channel_task(self, info: ChannelInfo) -> None:
+ '''
+ Forward every payload received on this channel into the internal
+ memory channel; exit on read cancellation, EOF or channel closure.
+
+ '''
+ while True:
+ try:
+ msg = await info.channel.receive()
+ await self._schan.send(msg)
+
+ except tractor.linux.eventfd.EFDReadCancelled as e:
+ # when channel gets removed while we are doing a receive
+ log.exception(e)
+ break
+
+ except trio.EndOfChannel:
+ break
+
+ except trio.ClosedResourceError:
+ break
+
+ async def receive(self) -> PayloadT:
+ '''
+ Receive the next msg (in original send order when the
+ `guarantee_order` wrapping is enabled).
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ if self._receive_lock.locked():
+ raise trio.BusyResourceError
+
+ async with self._receive_lock:
+ return await self._rchan.receive()
+
+ async def __aenter__(self):
+ self._is_closed = False
+ self._chanmngr.open()
+ return self
+
+ async def aclose(self) -> None:
+ if self.closed:
+ return
+
+ await self._chanmngr.close()
+ await self._schan.aclose()
+ await self._rchan.aclose()
+
+ self._is_closed = True
+
+
+'''
+Actor module for managing publisher & subscriber channels remotely
+through `tractor.context` RPC.
+
+'''
+
+
+@dataclass
+class PublisherEntry:
+ publisher: RingBufferPublisher | None = None
+ # NOTE: use a `default_factory` (assumes `field` is imported from
+ # `dataclasses`) so each entry gets its own `trio.Event`; a bare
+ # `trio.Event()` default would be shared across all instances
+ is_set: trio.Event = field(default_factory=trio.Event)
+
+
+_publishers: dict[str, PublisherEntry] = {}
+
+
+def maybe_init_publisher(topic: str) -> PublisherEntry:
+ entry = _publishers.get(topic, None)
+ if not entry:
+ entry = PublisherEntry()
+ _publishers[topic] = entry
+
+ return entry
+
+
+def set_publisher(topic: str, pub: RingBufferPublisher):
+ entry = maybe_init_publisher(topic)
+
+ if entry.publisher:
+ raise RuntimeError(
+ f'publisher for topic {topic} already set on {tractor.current_actor()}'
+ )
+
+ entry.publisher = pub
+ entry.is_set.set()
+
+
+def get_publisher(topic: str = 'default') -> RingBufferPublisher:
+ entry = _publishers.get(topic, None)
+ if not entry or not entry.publisher:
+ raise RuntimeError(
+ f'{tractor.current_actor()} tried to get publisher '
+ f'for topic {topic} but it\'s not set'
+ )
+
+ return entry.publisher
+
+
+async def wait_publisher(topic: str) -> RingBufferPublisher:
+ entry = maybe_init_publisher(topic)
+ await entry.is_set.wait()
+ return entry.publisher
+
+
+@tractor.context
+async def _add_pub_channel(
+ ctx: tractor.Context,
+ topic: str,
+ token: RBToken
+):
+ publisher = await wait_publisher(topic)
+ await publisher.add_channel(token)
+
+
+@tractor.context
+async def _remove_pub_channel(
+ ctx: tractor.Context,
+ topic: str,
+ ring_name: str
+):
+ publisher = await wait_publisher(topic)
+ maybe_token = fdshare.maybe_get_fds(ring_name)
+ if maybe_token:
+ await publisher.remove_channel(ring_name)
+
+
+@acm
+async def open_pub_channel_at(
+ actor_name: str,
+ token: RBToken,
+ topic: str = 'default',
+):
+ '''
+ Ask the actor named `actor_name` to add a channel for `token` to
+ its publisher for `topic`; remove the channel again on exit.
+
+ '''
+ async with tractor.find_actor(actor_name) as portal:
+ await portal.run(_add_pub_channel, topic=topic, token=token)
+ try:
+ yield
+
+ except trio.Cancelled:
+ log.warning(
+ 'open_pub_channel_at got cancelled!\n'
+ f'\tactor_name = {actor_name}\n'
+ f'\ttoken = {token}\n'
+ )
+ raise
+
+ await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name)
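+
+# a hedged usage sketch for the acm above; 'pub-actor' and `tok` are
+# illustrative placeholders:
+#
+#     async with open_pub_channel_at('pub-actor', tok):
+#         ...  # ring stays attached to the remote publisher in here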
+
+
+@dataclass
+class SubscriberEntry:
+ subscriber: RingBufferSubscriber | None = None
+ # NOTE: `default_factory` (again assuming `dataclasses.field` is
+ # imported) so each entry gets its own `trio.Event`
+ is_set: trio.Event = field(default_factory=trio.Event)
+
+
+_subscribers: dict[str, SubscriberEntry] = {}
+
+
+def maybe_init_subscriber(topic: str) -> SubscriberEntry:
+ entry = _subscribers.get(topic, None)
+ if not entry:
+ entry = SubscriberEntry()
+ _subscribers[topic] = entry
+
+ return entry
+
+
+def set_subscriber(topic: str, sub: RingBufferSubscriber):
+ entry = maybe_init_subscriber(topic)
+
+ if entry.subscriber:
+ raise RuntimeError(
+ f'subscriber for topic {topic} already set on {tractor.current_actor()}'
+ )
+
+ entry.subscriber = sub
+ entry.is_set.set()
+
+
+def get_subscriber(topic: str = 'default') -> RingBufferSubscriber:
+ entry = _subscribers.get(topic, None)
+ if not entry or not entry.subscriber:
+ raise RuntimeError(
+ f'{tractor.current_actor()} tried to get subscriber '
+ f'for topic {topic} but it\'s not set'
+ )
+
+ return entry.subscriber
+
+
+async def wait_subscriber(topic: str) -> RingBufferSubscriber:
+ entry = maybe_init_subscriber(topic)
+ await entry.is_set.wait()
+ return entry.subscriber
+
+
+@tractor.context
+async def _add_sub_channel(
+ ctx: tractor.Context,
+ topic: str,
+ token: RBToken
+):
+ subscriber = await wait_subscriber(topic)
+ await subscriber.add_channel(token)
+
+
+@tractor.context
+async def _remove_sub_channel(
+ ctx: tractor.Context,
+ topic: str,
+ ring_name: str
+):
+ subscriber = await wait_subscriber(topic)
+ maybe_token = fdshare.maybe_get_fds(ring_name)
+ if maybe_token:
+ await subscriber.remove_channel(ring_name)
+
+
+@acm
+async def open_sub_channel_at(
+ actor_name: str,
+ token: RBToken,
+ topic: str = 'default',
+):
+ '''
+ Ask the actor named `actor_name` to add a channel for `token` to
+ its subscriber for `topic`; remove the channel again on exit.
+
+ '''
+ async with tractor.find_actor(actor_name) as portal:
+ await portal.run(_add_sub_channel, topic=topic, token=token)
+ try:
+ yield
+
+ except trio.Cancelled:
+ log.warning(
+ 'open_sub_channel_at got cancelled!\n'
+ f'\tactor_name = {actor_name}\n'
+ f'\ttoken = {token}\n'
+ )
+ raise
+
+ await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name)
+
+
+'''
+High level helpers to open a publisher & subscriber.
+
+'''
+
+
+@acm
+async def open_ringbuf_publisher(
+ # name to distinguish this publisher
+ topic: str = 'default',
+
+ # global batch size for channels
+ batch_size: int = 1,
+
+ # messages before changing output channel
+ msgs_per_turn: int = 1,
+
+ encoder: Encoder | None = None,
+
+ # ensure subscriber receives in same order publisher sent
+ # causes it to use wrapped payloads which contain the og
+ # index
+ guarantee_order: bool = False,
+
+ # on creation, register this publisher in the module-level
+ # `_publishers` table under `topic` so the provided
+ # `tractor.context` endpoints can add & remove channels from
+ # remote actors
+ set_module_var: bool = True
+
+) -> AsyncContextManager[RingBufferPublisher]:
+ '''
+ Open a new ringbuf publisher
+
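+ A hedged sketch of the flow (actor & topic names are illustrative):
+
+ async with open_ringbuf_publisher(topic='feed') as pub:
+ # remote actors can now attach rings via
+ # `open_pub_channel_at('this-actor', token, topic='feed')`
+ await pub.send(b'payload')
+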
+ '''
+ async with (
+ trio.open_nursery(strict_exception_groups=False) as n,
+ RingBufferPublisher(
+ n,
+ batch_size=batch_size,
+ msgs_per_turn=msgs_per_turn,
+ encoder=encoder,
+ ) as publisher
+ ):
+ if guarantee_order:
+ order_send_channel(publisher)
+
+ if set_module_var:
+ set_publisher(topic, publisher)
+
+ yield publisher
+
+ n.cancel_scope.cancel()
+
+
+@acm
+async def open_ringbuf_subscriber(
+ # name to distinguish this subscriber
+ topic: str = 'default',
+
+ decoder: Decoder | None = None,
+
+ # expect indexed payloads and unwrap them in order
+ guarantee_order: bool = False,
+
+ # on creation, register this subscriber in the module-level
+ # `_subscribers` table under `topic` so the provided
+ # `tractor.context` endpoints can add & remove channels from
+ # remote actors
+ set_module_var: bool = True
+) -> AsyncContextManager[RingBufferSubscriber]:
+ '''
+ Open a new ringbuf subscriber
+
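+ A hedged sketch (topic name illustrative):
+
+ async with open_ringbuf_subscriber(topic='feed') as sub:
+ msg = await sub.receive()
+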
+ '''
+ async with (
+ trio.open_nursery(strict_exception_groups=False) as n,
+ RingBufferSubscriber(n, decoder=decoder) as subscriber
+ ):
+ # maybe monkey patch `.receive` to use indexed payloads
+ if guarantee_order:
+ order_receive_channel(subscriber)
+
+ # maybe set global module var for remote actor channel updates
+ if set_module_var:
+ set_subscriber(topic, subscriber)
+
+ yield subscriber
+
+ n.cancel_scope.cancel()
diff --git a/tractor/ipc/_transport.py b/tractor/ipc/_transport.py
index 6bfa5f6a..eb8ff3c9 100644
--- a/tractor/ipc/_transport.py
+++ b/tractor/ipc/_transport.py
@@ -78,7 +78,7 @@ class MsgTransport(Protocol):
# eventual msg definition/types?
# - https://docs.python.org/3/library/typing.html#typing.Protocol
- stream: trio.SocketStream
+ stream: trio.abc.Stream
drained: list[MsgType]
address_type: ClassVar[Type[Address]]
diff --git a/tractor/linux/__init__.py b/tractor/linux/__init__.py
new file mode 100644
index 00000000..33526d14
--- /dev/null
+++ b/tractor/linux/__init__.py
@@ -0,0 +1,15 @@
+# tractor: structured concurrent "actors".
+# Copyright 2018-eternity Tyler Goodlet.
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
diff --git a/tractor/linux/_fdshare.py b/tractor/linux/_fdshare.py
new file mode 100644
index 00000000..84681455
--- /dev/null
+++ b/tractor/linux/_fdshare.py
@@ -0,0 +1,316 @@
+# tractor: structured concurrent "actors".
+# Copyright 2018-eternity Tyler Goodlet.
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+'''
+Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and trio.
+
+cpython impl:
+https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
+'''
+import os
+import array
+import tempfile
+from uuid import uuid4
+from pathlib import Path
+from typing import AsyncContextManager
+from contextlib import asynccontextmanager as acm
+
+import trio
+import tractor
+from trio import socket
+
+
+log = tractor.log.get_logger(__name__)
+
+
+class FDSharingError(Exception):
+ ...
+
+
+@acm
+async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
+ '''
+ Async trio reimplementation of `multiprocessing.reduction.sendfds`
+
+ https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142
+
+ It's implemented as an async context manager to simplify usage with
+ `tractor.context`s: we can open a context in a remote actor that uses
+ this acm inside of it, and use `ctx.started()` to signal the original
+ caller actor to perform the `recv_fds` call.
+
+ See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
+ '''
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ await sock.bind(sock_path)
+ sock.listen(1)
+
+ yield  # socket is set up, ready for the receiver to connect
+
+ # wait until receiver connects
+ conn, _ = await sock.accept()
+
+ # setup int array for fds
+ fds = array.array('i', fds)
+
+ # the first byte of the msg is len(fds) % 256; `recv_fds` uses it to
+ # verify the fd amount, we refer to it as the `check_byte`
+ msg = bytes([len(fds) % 256])
+
+ # send msg with custom SCM_RIGHTS type
+ await conn.sendmsg(
+ [msg],
+ [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
+ )
+
+ # finally, wait for the receiver's ack
+ if await conn.recv(1) != b'A':
+ raise FDSharingError('did not receive acknowledgement of fd')
+
+ conn.close()
+ sock.close()
+ os.unlink(sock_path)
+
+
+async def recv_fds(sock_path: str, amount: int) -> tuple:
+ '''
+ Async trio reimplementation of `multiprocessing.reduction.recvfds`
+
+ https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150
+
+ Equivalent to the stdlib version but uses `trio.open_unix_socket`
+ to connect and differs in its error handling.
+
+ See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
+ '''
+ stream = await trio.open_unix_socket(sock_path)
+ sock = stream.socket
+
+ # prepare int array for fds
+ a = array.array('i')
+ bytes_size = a.itemsize * amount
+
+ # receive 1 byte + the space necessary for an SCM_RIGHTS msg of {amount} fds
+ msg, ancdata, flags, addr = await sock.recvmsg(
+ 1, socket.CMSG_SPACE(bytes_size)
+ )
+
+ # maybe failed to receive msg?
+ if not msg and not ancdata:
+ raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF')
+
+ # send ack; the stdlib comment mentions this ack pattern was to get
+ # around an old macOS bug, but they're not sure if it's necessary
+ # any more, in any case it's not a bad pattern to keep
+ await sock.send(b'A') # Ack
+
+ # expect to receive only one `ancdata` item
+ if len(ancdata) != 1:
+ raise FDSharingError(
+ f'Expected to receive exactly one \"ancdata\" but got {len(ancdata)}: {ancdata}'
+ )
+
+ # unpack SCM_RIGHTS msg
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+
+ # check proper msg type
+ if cmsg_level != socket.SOL_SOCKET:
+ raise FDSharingError(
+ f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}'
+ )
+
+ if cmsg_type != socket.SCM_RIGHTS:
+ raise FDSharingError(
+ f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}'
+ )
+
+ # check proper data alignment
+ length = len(cmsg_data)
+ if length % a.itemsize != 0:
+ raise FDSharingError(
+ f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}'
+ )
+
+ # attempt to cast as int array
+ a.frombytes(cmsg_data)
+
+ # validate length check byte
+ valid_check_byte = amount % 256  # check byte according to the caller's requested fd amount
+ recvd_check_byte = msg[0]  # actual received check byte
+ payload_check_byte = len(a) % 256  # check byte according to the received fd int array
+
+ if recvd_check_byte != payload_check_byte:
+ raise FDSharingError(
+ 'Validation failed: received check byte '
+ f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})'
+ )
+
+ if valid_check_byte != recvd_check_byte:
+ raise FDSharingError(
+ 'Validation failed: received check byte '
+ f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})'
+ )
+
+ return tuple(a)
+
+
+'''
+Share FD actor module
+
+Add "tractor.linux._fdshare" to enabled modules on actors to allow sharing of
+FDs with other actors.
+
+Use the `share_fds` function to register a set of fds under a name;
+other actors can then use the `request_fds_from` function to retrieve
+them.
+
+Use `unshare_fds` to disable sharing of a set of FDs.
+
+'''
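+
+# a hedged sketch of the flow described above (actor & fd names are
+# illustrative placeholders):
+#
+#     # in the actor owning the fds (with this module enabled):
+#     share_fds('my-ring', (write_fd, wrap_fd, eof_fd))
+#
+#     # in any other actor:
+#     fds = await request_fds_from('owner-actor', 'my-ring')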
+
+FDType = tuple[int, ...]
+
+_fds: dict[str, FDType] = {}
+
+
+def maybe_get_fds(name: str) -> FDType | None:
+ '''
+ Get registered FDs with a given name or return None
+
+ '''
+ return _fds.get(name, None)
+
+
+def get_fds(name: str) -> FDType:
+ '''
+ Get registered FDs with a given name or raise
+ '''
+ fds = maybe_get_fds(name)
+
+ if not fds:
+ raise RuntimeError(f'No FDs with name {name} found!')
+
+ return fds
+
+
+def share_fds(
+ name: str,
+ fds: tuple[int],
+) -> None:
+ '''
+ Register a set of fds to be shared under a given name.
+
+ '''
+ this_actor = tractor.current_actor()
+ if __name__ not in this_actor.enable_modules:
+ raise RuntimeError(
+ f'Tried to share FDs {fds} with name {name}, but '
+ f'module {__name__} is not enabled in actor {this_actor.name}!'
+ )
+
+ maybe_fds = maybe_get_fds(name)
+ if maybe_fds:
+ raise RuntimeError(f'share FDs: {maybe_fds} already tied to name {name}')
+
+ _fds[name] = fds
+
+
+def unshare_fds(name: str) -> None:
+ '''
+ Unregister a set of fds to disable sharing them.
+
+ '''
+ get_fds(name) # raise if not exists
+
+ del _fds[name]
+
+
+@tractor.context
+async def _pass_fds(
+ ctx: tractor.Context,
+ name: str,
+ sock_path: str
+) -> None:
+ '''
+ Endpoint to request a set of FDs from current actor, will use `ctx.started`
+ to send original FDs, then `send_fds` will block until remote side finishes
+ the `recv_fds` call.
+
+ '''
+ # get fds or raise error
+ fds = get_fds(name)
+
+ # start fd passing context using socket on `sock_path`
+ async with send_fds(fds, sock_path):
+ # send original fds through ctx.started
+ await ctx.started(fds)
+
+
+async def request_fds_from(
+ actor_name: str,
+ fds_name: str
+) -> FDType:
+ '''
+ Use this function to retrieve shared FDs from `actor_name`.
+
+ '''
+ this_actor = tractor.current_actor()
+
+ # create a temporary path for the UDS sock
+ sock_path = str(
+ Path(tempfile.gettempdir())
+ /
+ f'{fds_name}-from-{actor_name}-to-{this_actor.name}.sock'
+ )
+
+ # a socket path length of roughly > 100 chars can cause:
+ # OSError: AF_UNIX path too long
+ # https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/sys_un.h.html#tag_13_67_04
+
+ # attempt sock path creation with smaller names
+ if len(sock_path) > 100:
+ sock_path = str(
+ Path(tempfile.gettempdir())
+ /
+ f'{fds_name}-to-{this_actor.name}.sock'
+ )
+
+ if len(sock_path) > 100:
+ # just use uuid4
+ sock_path = str(
+ Path(tempfile.gettempdir())
+ /
+ f'pass-fds-{uuid4()}.sock'
+ )
+
+ async with (
+ tractor.find_actor(actor_name) as portal,
+
+ portal.open_context(
+ _pass_fds,
+ name=fds_name,
+ sock_path=sock_path
+ ) as (ctx, fds_info),
+ ):
+ # get original FDs
+ og_fds = fds_info
+
+ # retrieve copies of FDs
+ fds = await recv_fds(sock_path, len(og_fds))
+
+ log.info(
+ f'{this_actor.name} received fds: {og_fds} -> {fds}'
+ )
+
+ return fds
diff --git a/tractor/ipc/_linux.py b/tractor/linux/eventfd.py
similarity index 68%
rename from tractor/ipc/_linux.py
rename to tractor/linux/eventfd.py
index 88d80d1c..1b00a190 100644
--- a/tractor/ipc/_linux.py
+++ b/tractor/linux/eventfd.py
@@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
'''
-Linux specifics, for now we are only exposing EventFD
+Expose libc eventfd APIs
'''
import os
@@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
raise OSError(errno.errorcode[ffi.errno], 'close failed')
+class EFDReadCancelled(Exception):
+ ...
+
+
class EventFD:
'''
Use a previously opened eventfd(2), meant to be used in
@@ -124,26 +128,82 @@ class EventFD:
self._fd: int = fd
self._omode: str = omode
self._fobj = None
+ self._cscope: trio.CancelScope | None = None
+ self._is_closed: bool = True
+ self._read_lock = trio.StrictFIFOLock()
+
+ @property
+ def closed(self) -> bool:
+ return self._is_closed
@property
def fd(self) -> int | None:
return self._fd
def write(self, value: int) -> int:
+ if self.closed:
+ raise trio.ClosedResourceError
+
return write_eventfd(self._fd, value)
async def read(self) -> int:
- return await trio.to_thread.run_sync(
- read_eventfd, self._fd,
- abandon_on_cancel=True
- )
+ '''
+ Async wrapper for `read_eventfd(self.fd)`.
+
+ Runs in a worker thread via `trio.to_thread.run_sync`, wrapped in a
+ `trio.CancelScope` so the read can be cancelled when `self.close()`
+ is called.
+
+ '''
+ if self.closed:
+ raise trio.ClosedResourceError
+
+ if self._read_lock.locked():
+ raise trio.BusyResourceError
+
+ async with self._read_lock:
+ self._cscope = trio.CancelScope()
+ with self._cscope:
+ try:
+ return await trio.to_thread.run_sync(
+ read_eventfd, self._fd,
+ abandon_on_cancel=True
+ )
+
+ except OSError as e:
+ if e.errno != errno.EBADF:
+ raise
+
+ raise trio.BrokenResourceError
+
+ if self._cscope.cancelled_caught:
+ raise EFDReadCancelled
+
+ self._cscope = None
+
+ def read_nowait(self) -> int:
+ '''
+ Direct call to `read_eventfd(self.fd)`; unless the `eventfd` was
+ opened with `EFD_NONBLOCK` this will block the calling thread.
+
+ '''
+ return read_eventfd(self._fd)
def open(self):
self._fobj = os.fdopen(self._fd, self._omode)
+ self._is_closed = False
def close(self):
if self._fobj:
- self._fobj.close()
+ try:
+ self._fobj.close()
+
+ except OSError:
+ ...
+
+ if self._cscope:
+ self._cscope.cancel()
+
+ self._is_closed = True
def __enter__(self):
self.open()
diff --git a/tractor/trionics/__init__.py b/tractor/trionics/__init__.py
index df9b6f26..97d03da7 100644
--- a/tractor/trionics/__init__.py
+++ b/tractor/trionics/__init__.py
@@ -32,3 +32,8 @@ from ._broadcast import (
from ._beg import (
collapse_eg as collapse_eg,
)
+
+from ._ordering import (
+ order_send_channel as order_send_channel,
+ order_receive_channel as order_receive_channel
+)
diff --git a/tractor/trionics/_mngrs.py b/tractor/trionics/_mngrs.py
index 9a5ed156..24b4fde8 100644
--- a/tractor/trionics/_mngrs.py
+++ b/tractor/trionics/_mngrs.py
@@ -70,7 +70,8 @@ async def maybe_open_nursery(
yield nursery
else:
async with lib.open_nursery(**kwargs) as nursery:
- nursery.cancel_scope.shield = shield
+ if lib == trio:
+ nursery.cancel_scope.shield = shield
yield nursery
diff --git a/tractor/trionics/_ordering.py b/tractor/trionics/_ordering.py
new file mode 100644
index 00000000..0cc89b4b
--- /dev/null
+++ b/tractor/trionics/_ordering.py
@@ -0,0 +1,108 @@
+# tractor: structured concurrent "actors".
+# Copyright 2018-eternity Tyler Goodlet.
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+'''
+Helpers to guarantee ordering of messages through an unordered channel
+
+'''
+from __future__ import annotations
+from heapq import (
+ heappush,
+ heappop
+)
+
+import trio
+import msgspec
+
+
+class OrderedPayload(msgspec.Struct, frozen=True):
+ index: int
+ payload: bytes
+
+ @classmethod
+ def from_msg(cls, msg: bytes) -> OrderedPayload:
+ return msgspec.msgpack.decode(msg, type=OrderedPayload)
+
+ def encode(self) -> bytes:
+ return msgspec.msgpack.encode(self)
+
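+# round-trip sketch of the wire format:
+#
+#     OrderedPayload.from_msg(
+#         OrderedPayload(index=0, payload=b'hi').encode()
+#     )  # -> OrderedPayload(index=0, payload=b'hi')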
+
+def order_send_channel(
+ channel: trio.abc.SendChannel[bytes],
+ start_index: int = 0
+):
+
+ next_index = start_index
+ send_lock = trio.StrictFIFOLock()
+
+ channel._send = channel.send
+ channel._aclose = channel.aclose
+
+ async def send(msg: bytes):
+ nonlocal next_index
+ async with send_lock:
+ await channel._send(
+ OrderedPayload(
+ index=next_index,
+ payload=msg
+ ).encode()
+ )
+ next_index += 1
+
+ async def aclose():
+ async with send_lock:
+ await channel._aclose()
+
+ channel.send = send
+ channel.aclose = aclose
+
+
+def order_receive_channel(
+ channel: trio.abc.ReceiveChannel[bytes],
+ start_index: int = 0
+):
+ next_index = start_index
+ pqueue = []
+
+ channel._receive = channel.receive
+
+ def can_pop_next() -> bool:
+ return (
+ len(pqueue) > 0
+ and
+ pqueue[0][0] == next_index
+ )
+
+ async def drain_to_heap():
+ while not can_pop_next():
+ msg = await channel._receive()
+ msg = OrderedPayload.from_msg(msg)
+ heappush(pqueue, (msg.index, msg.payload))
+
+ def pop_next():
+ nonlocal next_index
+ _, msg = heappop(pqueue)
+ next_index += 1
+ return msg
+
+ async def receive() -> bytes:
+ if can_pop_next():
+ return pop_next()
+
+ await drain_to_heap()
+
+ return pop_next()
+
+ channel.receive = receive
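+
+
+# NOTE: these helpers are applied by `open_ringbuf_publisher()` and
+# `open_ringbuf_subscriber()` when `guarantee_order=True`: the send
+# side wraps each msg as an indexed `OrderedPayload` while the receive
+# side heap-buffers out-of-order msgs and yields them strictly by index.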