Linux specific IPC RingBuff using EventFD for async reader wakeup #10

Open
guille wants to merge 41 commits from one_ring_to_rule_them_all into auto_codecs
16 changed files with 2982 additions and 404 deletions

View File

@ -0,0 +1,66 @@
import trio
import pytest
from tractor.linux.eventfd import (
open_eventfd,
EFDReadCancelled,
EventFD
)
def test_read_cancellation():
'''
Ensure EventFD.read raises EFDReadCancelled if EventFD.close()
is called.
'''
fd = open_eventfd()
async def bg_read(event: EventFD):
with pytest.raises(EFDReadCancelled):
await event.read()
async def main():
async with trio.open_nursery() as n:
with (
EventFD(fd, 'w') as event,
trio.fail_after(3)
):
n.start_soon(bg_read, event)
await trio.sleep(0.2)
event.close()
trio.run(main)
def test_read_trio_semantics():
'''
Ensure EventFD.read raises trio.ClosedResourceError and
trio.BusyResourceError.
'''
fd = open_eventfd()
async def bg_read(event: EventFD):
try:
await event.read()
except EFDReadCancelled:
...
async def main():
async with trio.open_nursery() as n:
# start background read and attempt
# foreground read, should be busy
with EventFD(fd, 'w') as event:
n.start_soon(bg_read, event)
await trio.sleep(0.2)
with pytest.raises(trio.BusyResourceError):
await event.read()
# attempt read after close
with pytest.raises(trio.ClosedResourceError):
await event.read()
trio.run(main)

View File

@ -0,0 +1,185 @@
from typing import AsyncContextManager
from contextlib import asynccontextmanager as acm
import trio
import pytest
import tractor
from tractor.trionics import gather_contexts
from tractor.ipc._ringbuf import open_ringbufs
from tractor.ipc._ringbuf._pubsub import (
open_ringbuf_publisher,
open_ringbuf_subscriber,
get_publisher,
get_subscriber,
open_pub_channel_at,
open_sub_channel_at
)
log = tractor.log.get_console_log(level='info')
@tractor.context
async def publish_range(
ctx: tractor.Context,
size: int
):
pub = get_publisher()
await ctx.started()
for i in range(size):
await pub.send(i.to_bytes(4))
log.info(f'sent {i}')
await pub.flush()
log.info('range done')
@tractor.context
async def subscribe_range(
ctx: tractor.Context,
size: int
):
sub = get_subscriber()
await ctx.started()
for i in range(size):
recv = int.from_bytes(await sub.receive())
if recv != i:
raise AssertionError(
f'received: {recv} expected: {i}'
)
log.info(f'received: {recv}')
log.info('range done')
@tractor.context
async def subscriber_child(ctx: tractor.Context):
try:
async with open_ringbuf_subscriber(guarantee_order=True):
await ctx.started()
await trio.sleep_forever()
finally:
log.info('subscriber exit')
@tractor.context
async def publisher_child(
ctx: tractor.Context,
batch_size: int
):
try:
async with open_ringbuf_publisher(
guarantee_order=True,
batch_size=batch_size
):
await ctx.started()
await trio.sleep_forever()
finally:
log.info('publisher exit')
@acm
async def open_pubsub_test_actors(
ring_names: list[str],
size: int,
batch_size: int
) -> AsyncContextManager[tuple[tractor.Portal, tractor.Portal]]:
with trio.fail_after(5):
async with tractor.open_nursery(
enable_modules=[
'tractor.linux._fdshare'
]
) as an:
modules = [
__name__,
'tractor.linux._fdshare',
'tractor.ipc._ringbuf._pubsub'
]
sub_portal = await an.start_actor(
'sub',
enable_modules=modules
)
pub_portal = await an.start_actor(
'pub',
enable_modules=modules
)
async with (
sub_portal.open_context(subscriber_child) as (long_rctx, _),
pub_portal.open_context(
publisher_child,
batch_size=batch_size
) as (long_sctx, _),
open_ringbufs(ring_names) as tokens,
gather_contexts([
open_sub_channel_at('sub', ring)
for ring in tokens
]),
gather_contexts([
open_pub_channel_at('pub', ring)
for ring in tokens
]),
sub_portal.open_context(subscribe_range, size=size) as (rctx, _),
pub_portal.open_context(publish_range, size=size) as (sctx, _)
):
yield
await rctx.wait_for_result()
await sctx.wait_for_result()
await long_sctx.cancel()
await long_rctx.cancel()
await an.cancel()
@pytest.mark.parametrize(
('ring_names', 'size', 'batch_size'),
[
(
['ring-first'],
100,
1
),
(
['ring-first'],
69,
1
),
(
[f'multi-ring-{i}' for i in range(3)],
1000,
100
),
],
ids=[
'simple',
'redo-simple',
'multi-ring',
]
)
def test_pubsub(
request,
ring_names: list[str],
size: int,
batch_size: int
):
async def main():
async with open_pubsub_test_actors(
ring_names, size, batch_size
):
...
trio.run(main)

View File

@ -1,4 +1,5 @@
import time
import hashlib
import trio
import pytest
@ -6,36 +7,45 @@ import pytest
import tractor
from tractor.ipc._ringbuf import (
open_ringbuf,
open_ringbuf_pair,
attach_to_ringbuf_receiver,
attach_to_ringbuf_sender,
attach_to_ringbuf_channel,
RBToken,
RingBuffSender,
RingBuffReceiver
)
from tractor._testing.samples import (
generate_sample_messages,
generate_single_byte_msgs,
RandomBytesGenerator
)
# in case you don't want to melt your cores, uncomment dis!
pytestmark = pytest.mark.skip
@tractor.context
async def child_read_shm(
ctx: tractor.Context,
msg_amount: int,
token: RBToken,
total_bytes: int,
) -> None:
recvd_bytes = 0
await ctx.started()
start_ts = time.time()
async with RingBuffReceiver(token) as receiver:
while recvd_bytes < total_bytes:
msg = await receiver.receive_some()
recvd_bytes += len(msg)
) -> str:
'''
Sub-actor used in `test_ringbuf`.
# make sure we dont hold any memoryviews
# before the ctx manager aclose()
msg = None
Attach to a ringbuf and receive all messages until end of stream.
Keep track of how many bytes received and also calculate
sha256 of the whole byte stream.
Calculate and print performance stats, finally return calculated
hash.
'''
await ctx.started()
print('reader started')
msg_amount = 0
recvd_bytes = 0
recvd_hash = hashlib.sha256()
start_ts = time.time()
async with attach_to_ringbuf_receiver(token) as receiver:
async for msg in receiver:
msg_amount += 1
recvd_hash.update(msg)
recvd_bytes += len(msg)
end_ts = time.time()
elapsed = end_ts - start_ts
@ -44,6 +54,10 @@ async def child_read_shm(
print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
print(f'\treceived msgs: {msg_amount:,}')
print(f'\treceived bytes: {recvd_bytes:,}')
return recvd_hash.hexdigest()
@tractor.context
@ -52,17 +66,37 @@ async def child_write_shm(
msg_amount: int,
rand_min: int,
rand_max: int,
token: RBToken,
buf_size: int
) -> None:
msgs, total_bytes = generate_sample_messages(
'''
Sub-actor used in `test_ringbuf`
Generate `msg_amount` payloads with
`random.randint(rand_min, rand_max)` random bytes at the end,
Calculate sha256 hash and send it to parent on `ctx.started`.
Attach to ringbuf and send all generated messages.
'''
rng = RandomBytesGenerator(
msg_amount,
rand_min=rand_min,
rand_max=rand_max,
)
await ctx.started(total_bytes)
async with RingBuffSender(token) as sender:
for msg in msgs:
await sender.send_all(msg)
async with (
open_ringbuf('test_ringbuf', buf_size=buf_size) as token,
attach_to_ringbuf_sender(token) as sender
):
await ctx.started(token)
print('writer started')
for msg in rng:
await sender.send(msg)
if rng.msgs_generated % rng.recommended_log_interval == 0:
print(f'wrote {rng.msgs_generated} msgs')
print('writer exit')
return rng.hexdigest
@pytest.mark.parametrize(
@ -89,84 +123,91 @@ def test_ringbuf(
rand_max: int,
buf_size: int
):
'''
- Open a new ring buf on root actor
- Open `child_write_shm` ctx in sub-actor which will generate a
random payload and send its hash on `ctx.started`, finally sending
the payload through the stream.
- Open `child_read_shm` ctx in sub-actor which will receive the
payload, calculate perf stats and return the hash.
- Compare both hashes
'''
async def main():
with open_ringbuf(
'test_ringbuf',
buf_size=buf_size
) as token:
proc_kwargs = {
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
}
async with tractor.open_nursery() as an:
send_p = await an.start_actor(
'ring_sender',
enable_modules=[
__name__,
'tractor.linux._fdshare'
],
)
recv_p = await an.start_actor(
'ring_receiver',
enable_modules=[
__name__,
'tractor.linux._fdshare'
],
)
async with (
send_p.open_context(
child_write_shm,
msg_amount=msg_amount,
rand_min=rand_min,
rand_max=rand_max,
buf_size=buf_size
) as (sctx, token),
common_kwargs = {
'msg_amount': msg_amount,
'token': token,
}
async with tractor.open_nursery() as an:
send_p = await an.start_actor(
'ring_sender',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
recv_p = await an.start_actor(
'ring_receiver',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
async with (
send_p.open_context(
child_write_shm,
rand_min=rand_min,
rand_max=rand_max,
**common_kwargs
) as (sctx, total_bytes),
recv_p.open_context(
child_read_shm,
**common_kwargs,
total_bytes=total_bytes,
) as (sctx, _sent),
):
await recv_p.result()
recv_p.open_context(
child_read_shm,
token=token,
) as (rctx, _),
):
sent_hash = await sctx.result()
recvd_hash = await rctx.result()
await send_p.cancel_actor()
await recv_p.cancel_actor()
assert sent_hash == recvd_hash
await an.cancel()
trio.run(main)
@tractor.context
async def child_blocked_receiver(
ctx: tractor.Context,
token: RBToken
):
async with RingBuffReceiver(token) as receiver:
await ctx.started()
async def child_blocked_receiver(ctx: tractor.Context):
async with (
open_ringbuf('test_ring_cancel_reader') as token,
attach_to_ringbuf_receiver(token) as receiver
):
await ctx.started(token)
await receiver.receive_some()
def test_ring_reader_cancel():
def test_reader_cancel():
'''
Test that a receiver blocked on eventfd(2) read responds to
cancellation.
'''
async def main():
with open_ringbuf('test_ring_cancel_reader') as token:
async with tractor.open_nursery() as an:
recv_p = await an.start_actor(
'ring_blocked_receiver',
enable_modules=[
__name__,
'tractor.linux._fdshare'
],
)
async with (
tractor.open_nursery() as an,
RingBuffSender(token) as _sender,
recv_p.open_context(
child_blocked_receiver,
) as (sctx, token),
attach_to_ringbuf_sender(token),
):
recv_p = await an.start_actor(
'ring_blocked_receiver',
enable_modules=[__name__],
proc_kwargs={
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
}
)
async with (
recv_p.open_context(
child_blocked_receiver,
token=token
) as (sctx, _sent),
):
await trio.sleep(1)
await an.cancel()
await trio.sleep(.1)
await an.cancel()
with pytest.raises(tractor._exceptions.ContextCancelled):
@ -174,38 +215,166 @@ def test_ring_reader_cancel():
@tractor.context
async def child_blocked_sender(
ctx: tractor.Context,
token: RBToken
):
async with RingBuffSender(token) as sender:
await ctx.started()
async def child_blocked_sender(ctx: tractor.Context):
async with (
open_ringbuf(
'test_ring_cancel_sender',
buf_size=1
) as token,
attach_to_ringbuf_sender(token) as sender
):
await ctx.started(token)
await sender.send_all(b'this will wrap')
def test_ring_sender_cancel():
def test_sender_cancel():
'''
Test that a sender blocked on eventfd(2) read responds to
cancellation.
'''
async def main():
with open_ringbuf(
'test_ring_cancel_sender',
buf_size=1
) as token:
async with tractor.open_nursery() as an:
recv_p = await an.start_actor(
'ring_blocked_sender',
enable_modules=[__name__],
proc_kwargs={
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
}
)
async with (
recv_p.open_context(
child_blocked_sender,
token=token
) as (sctx, _sent),
):
await trio.sleep(1)
await an.cancel()
async with tractor.open_nursery() as an:
recv_p = await an.start_actor(
'ring_blocked_sender',
enable_modules=[
__name__,
'tractor.linux._fdshare'
],
)
async with (
recv_p.open_context(
child_blocked_sender,
) as (sctx, token),
attach_to_ringbuf_receiver(token)
):
await trio.sleep(.1)
await an.cancel()
with pytest.raises(tractor._exceptions.ContextCancelled):
trio.run(main)
def test_receiver_max_bytes():
'''
Test that RingBuffReceiver.receive_some's max_bytes optional
argument works correctly, send a msg of size 100, then
force receive of messages with max_bytes == 1, wait until
100 of these messages are received, then compare join of
msgs with original message
'''
msg = generate_single_byte_msgs(100)
msgs = []
rb_common = {
'cleanup': False,
'is_ipc': False
}
async def main():
async with (
open_ringbuf(
'test_ringbuf_max_bytes',
buf_size=10,
is_ipc=False
) as token,
trio.open_nursery() as n,
attach_to_ringbuf_sender(token, **rb_common) as sender,
attach_to_ringbuf_receiver(token, **rb_common) as receiver
):
async def _send_and_close():
await sender.send_all(msg)
await sender.aclose()
n.start_soon(_send_and_close)
while len(msgs) < len(msg):
msg_part = await receiver.receive_some(max_bytes=1)
assert len(msg_part) == 1
msgs.append(msg_part)
trio.run(main)
assert msg == b''.join(msgs)
@tractor.context
async def child_channel_sender(
ctx: tractor.Context,
msg_amount_min: int,
msg_amount_max: int,
token_in: RBToken,
token_out: RBToken
):
import random
rng = RandomBytesGenerator(
random.randint(msg_amount_min, msg_amount_max),
rand_min=256,
rand_max=1024,
)
async with attach_to_ringbuf_channel(
token_in,
token_out
) as chan:
await ctx.started()
for msg in rng:
await chan.send(msg)
await chan.send(b'bye')
await chan.receive()
return rng.hexdigest
def test_channel():
msg_amount_min = 100
msg_amount_max = 1000
mods = [
__name__,
'tractor.linux._fdshare'
]
async def main():
async with (
tractor.open_nursery(enable_modules=mods) as an,
open_ringbuf_pair(
'test_ringbuf_transport'
) as (send_token, recv_token),
attach_to_ringbuf_channel(send_token, recv_token) as chan,
):
sender = await an.start_actor(
'test_ringbuf_transport_sender',
enable_modules=mods,
)
async with (
sender.open_context(
child_channel_sender,
msg_amount_min=msg_amount_min,
msg_amount_max=msg_amount_max,
token_in=recv_token,
token_out=send_token
) as (ctx, _),
):
recvd_hash = hashlib.sha256()
async for msg in chan:
if msg == b'bye':
await chan.send(b'bye')
break
recvd_hash.update(msg)
sent_hash = await ctx.result()
assert recvd_hash.hexdigest() == sent_hash
await an.cancel()
trio.run(main)

View File

@ -121,9 +121,14 @@ def get_peer_by_name(
actor: Actor = current_actor()
server: IPCServer = actor.ipc_server
to_scan: dict[tuple, list[Channel]] = server._peers.copy()
pchan: Channel|None = actor._parent_chan
if pchan:
to_scan[pchan.uid].append(pchan)
# TODO: is this ever needed? creates a duplicate channel on actor._peers
# when multiple find_actor calls are made to same actor from a single ctx
# which causes actor exit to hang waiting forever on
# `actor._no_more_peers.wait()` in `_runtime.async_main`
# pchan: Channel|None = actor._parent_chan
# if pchan:
# to_scan[pchan.uid].append(pchan)
for aid, chans in to_scan.items():
_, peer_name = aid

View File

@ -1,35 +1,99 @@
import os
import random
import hashlib
import numpy as np
def generate_sample_messages(
amount: int,
rand_min: int = 0,
rand_max: int = 0,
silent: bool = False
) -> tuple[list[bytes], int]:
def generate_single_byte_msgs(amount: int) -> bytes:
'''
Generate a byte instance of length `amount` with repeating ASCII digits 0..9.
msgs = []
size = 0
'''
# array [0, 1, 2, ..., amount-1], take mod 10 => [0..9], and map 0->'0'(48)
# up to 9->'9'(57).
arr = np.arange(amount, dtype=np.uint8) % 10
# move into ascii space
arr += 48
return arr.tobytes()
if not silent:
print(f'\ngenerating {amount} messages...')
for i in range(amount):
msg = f'[{i:08}]'.encode('utf-8')
class RandomBytesGenerator:
'''
Generate bytes msgs for tests.
if rand_max > 0:
msg += os.urandom(
random.randint(rand_min, rand_max))
messages will have the following format:
size += len(msg)
b'[{i:08}]' + random_bytes
msgs.append(msg)
so for message index 25:
if not silent and i and i % 10_000 == 0:
print(f'{i} generated')
b'[00000025]' + random_bytes
if not silent:
print(f'done, {size:,} bytes in total')
also generates sha256 hash of msgs.
return msgs, size
'''
def __init__(
self,
amount: int,
rand_min: int = 0,
rand_max: int = 0
):
if rand_max < rand_min:
raise ValueError('rand_max must be >= rand_min')
self._amount = amount
self._rand_min = rand_min
self._rand_max = rand_max
self._index = 0
self._hasher = hashlib.sha256()
self._total_bytes = 0
self._lengths = np.random.randint(
rand_min,
rand_max + 1,
size=amount,
dtype=np.int32
)
def __iter__(self):
return self
def __next__(self) -> bytes:
if self._index == self._amount:
raise StopIteration
header = f'[{self._index:08}]'.encode('utf-8')
length = int(self._lengths[self._index])
msg = header + np.random.bytes(length)
self._hasher.update(msg)
self._total_bytes += length
self._index += 1
return msg
@property
def hexdigest(self) -> str:
return self._hasher.hexdigest()
@property
def total_bytes(self) -> int:
return self._total_bytes
@property
def total_msgs(self) -> int:
return self._amount
@property
def msgs_generated(self) -> int:
return self._index
@property
def recommended_log_interval(self) -> int:
max_msg_size = 10 + self._rand_max
if max_msg_size <= 32 * 1024:
return 10_000
else:
return 1000

View File

@ -13,7 +13,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
A modular IPC layer supporting the power of cross-process SC!

View File

@ -1,253 +0,0 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
IPC Reliable RingBuffer implementation
'''
from __future__ import annotations
from contextlib import contextmanager as cm
from multiprocessing.shared_memory import SharedMemory
import trio
from msgspec import (
Struct,
to_builtins
)
from ._linux import (
EFD_NONBLOCK,
open_eventfd,
EventFD
)
from ._mp_bs import disable_mantracker
disable_mantracker()
class RBToken(Struct, frozen=True):
'''
RingBuffer token contains necesary info to open the two
eventfds and the shared memory
'''
shm_name: str
write_eventfd: int
wrap_eventfd: int
buf_size: int
def as_msg(self):
return to_builtins(self)
@classmethod
def from_msg(cls, msg: dict) -> RBToken:
if isinstance(msg, RBToken):
return msg
return RBToken(**msg)
@cm
def open_ringbuf(
shm_name: str,
buf_size: int = 10 * 1024,
write_efd_flags: int = 0,
wrap_efd_flags: int = 0
) -> RBToken:
shm = SharedMemory(
name=shm_name,
size=buf_size,
create=True
)
try:
token = RBToken(
shm_name=shm_name,
write_eventfd=open_eventfd(flags=write_efd_flags),
wrap_eventfd=open_eventfd(flags=wrap_efd_flags),
buf_size=buf_size
)
yield token
finally:
shm.unlink()
class RingBuffSender(trio.abc.SendStream):
'''
IPC Reliable Ring Buffer sender side implementation
`eventfd(2)` is used for wrap around sync, and also to signal
writes to the reader.
'''
def __init__(
self,
token: RBToken,
start_ptr: int = 0,
):
token = RBToken.from_msg(token)
self._shm = SharedMemory(
name=token.shm_name,
size=token.buf_size,
create=False
)
self._write_event = EventFD(token.write_eventfd, 'w')
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
self._ptr = start_ptr
@property
def key(self) -> str:
return self._shm.name
@property
def size(self) -> int:
return self._shm.size
@property
def ptr(self) -> int:
return self._ptr
@property
def write_fd(self) -> int:
return self._write_event.fd
@property
def wrap_fd(self) -> int:
return self._wrap_event.fd
async def send_all(self, data: bytes | bytearray | memoryview):
# while data is larger than the remaining buf
target_ptr = self.ptr + len(data)
while target_ptr > self.size:
# write all bytes that fit
remaining = self.size - self.ptr
self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around
self._write_event.write(remaining)
await self._wrap_event.read()
# wrap around and trim already written bytes
self._ptr = 0
data = data[remaining:]
target_ptr = self._ptr + len(data)
# remaining data fits on buffer
self._shm.buf[self.ptr:target_ptr] = data
self._write_event.write(len(data))
self._ptr = target_ptr
async def wait_send_all_might_not_block(self):
raise NotImplementedError
async def aclose(self):
self._write_event.close()
self._wrap_event.close()
self._shm.close()
async def __aenter__(self):
self._write_event.open()
self._wrap_event.open()
return self
class RingBuffReceiver(trio.abc.ReceiveStream):
'''
IPC Reliable Ring Buffer receiver side implementation
`eventfd(2)` is used for wrap around sync, and also to signal
writes to the reader.
'''
def __init__(
self,
token: RBToken,
start_ptr: int = 0,
flags: int = 0
):
token = RBToken.from_msg(token)
self._shm = SharedMemory(
name=token.shm_name,
size=token.buf_size,
create=False
)
self._write_event = EventFD(token.write_eventfd, 'w')
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
self._ptr = start_ptr
self._flags = flags
@property
def key(self) -> str:
return self._shm.name
@property
def size(self) -> int:
return self._shm.size
@property
def ptr(self) -> int:
return self._ptr
@property
def write_fd(self) -> int:
return self._write_event.fd
@property
def wrap_fd(self) -> int:
return self._wrap_event.fd
async def receive_some(
self,
max_bytes: int | None = None,
nb_timeout: float = 0.1
) -> memoryview:
# if non blocking eventfd enabled, do polling
# until next write, this allows signal handling
if self._flags | EFD_NONBLOCK:
delta = None
while delta is None:
try:
delta = await self._write_event.read()
except OSError as e:
if e.errno == 'EAGAIN':
continue
raise e
else:
delta = await self._write_event.read()
# fetch next segment and advance ptr
next_ptr = self._ptr + delta
segment = self._shm.buf[self._ptr:next_ptr]
self._ptr = next_ptr
if self.ptr == self.size:
# reached the end, signal wrap around
self._ptr = 0
self._wrap_event.write(1)
return segment
async def aclose(self):
self._write_event.close()
self._wrap_event.close()
self._shm.close()
async def __aenter__(self):
self._write_event.open()
self._wrap_event.open()
return self

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,834 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Ring buffer ipc publish-subscribe mechanism brokered by ringd
can dynamically add new outputs (publisher) or inputs (subscriber)
'''
from typing import (
TypeVar,
Generic,
Callable,
Awaitable,
AsyncContextManager
)
from functools import partial
from contextlib import asynccontextmanager as acm
from dataclasses import dataclass
import trio
import tractor
from msgspec.msgpack import (
Encoder,
Decoder
)
from tractor.ipc._ringbuf import (
RBToken,
PayloadT,
RingBufferSendChannel,
RingBufferReceiveChannel,
attach_to_ringbuf_sender,
attach_to_ringbuf_receiver
)
from tractor.trionics import (
order_send_channel,
order_receive_channel
)
import tractor.linux._fdshare as fdshare
log = tractor.log.get_logger(__name__)
ChannelType = TypeVar('ChannelType')
@dataclass
class ChannelInfo:
token: RBToken
channel: ChannelType
cancel_scope: trio.CancelScope
teardown: trio.Event
class ChannelManager(Generic[ChannelType]):
'''
Helper for managing channel resources and their handler tasks with
cancellation, add or remove channels dynamically!
'''
def __init__(
self,
# nursery used to spawn channel handler tasks
n: trio.Nursery,
# acm will be used for setup & teardown of channel resources
open_channel_acm: Callable[..., AsyncContextManager[ChannelType]],
# long running bg task to handle channel
channel_task: Callable[..., Awaitable[None]]
):
self._n = n
self._open_channel = open_channel_acm
self._channel_task = channel_task
# signal when a new channel conects and we previously had none
self._connect_event = trio.Event()
# store channel runtime variables
self._channels: list[ChannelInfo] = []
self._is_closed: bool = True
@property
def closed(self) -> bool:
return self._is_closed
@property
def channels(self) -> list[ChannelInfo]:
return self._channels
async def _channel_handler_task(
self,
token: RBToken,
task_status=trio.TASK_STATUS_IGNORED,
**kwargs
):
'''
Open channel resources, add to internal data structures, signal channel
connect through trio.Event, and run `channel_task` with cancel scope,
and finally, maybe remove channel from internal data structures.
Spawned by `add_channel` function, lock is held from begining of fn
until `task_status.started()` call.
kwargs are proxied to `self._open_channel` acm.
'''
async with self._open_channel(
token,
**kwargs
) as chan:
cancel_scope = trio.CancelScope()
info = ChannelInfo(
token=token,
channel=chan,
cancel_scope=cancel_scope,
teardown=trio.Event()
)
self._channels.append(info)
if len(self) == 1:
self._connect_event.set()
task_status.started()
with cancel_scope:
await self._channel_task(info)
self._maybe_destroy_channel(token.shm_name)
def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
'''
Given a channel name maybe return its index and value from
internal _channels list.
Only use after acquiring lock.
'''
for entry in enumerate(self._channels):
i, info = entry
if info.token.shm_name == name:
return entry
return None
def _maybe_destroy_channel(self, name: str):
'''
If channel exists cancel its scope and remove from internal
_channels list.
'''
maybe_entry = self._find_channel(name)
if maybe_entry:
i, info = maybe_entry
info.cancel_scope.cancel()
info.teardown.set()
del self._channels[i]
async def add_channel(
self,
token: RBToken,
**kwargs
):
'''
Add a new channel to be handled
'''
if self.closed:
raise trio.ClosedResourceError
await self._n.start(partial(
self._channel_handler_task,
RBToken.from_msg(token),
**kwargs
))
async def remove_channel(self, name: str):
'''
Remove a channel and stop its handling
'''
if self.closed:
raise trio.ClosedResourceError
maybe_entry = self._find_channel(name)
if not maybe_entry:
# return
raise RuntimeError(
f'tried to remove channel {name} but if does not exist'
)
i, info = maybe_entry
self._maybe_destroy_channel(name)
await info.teardown.wait()
# if that was last channel reset connect event
if len(self) == 0:
self._connect_event = trio.Event()
async def wait_for_channel(self):
'''
Wait until at least one channel added
'''
if self.closed:
raise trio.ClosedResourceError
await self._connect_event.wait()
self._connect_event = trio.Event()
def __len__(self) -> int:
return len(self._channels)
def __getitem__(self, name: str):
maybe_entry = self._find_channel(name)
if maybe_entry:
_, info = maybe_entry
return info
raise KeyError(f'Channel {name} not found!')
def open(self):
self._is_closed = False
async def close(self) -> None:
if self.closed:
log.warning('tried to close ChannelManager but its already closed...')
return
for info in self._channels:
if info.channel.closed:
continue
await info.channel.aclose()
await self.remove_channel(info.token.shm_name)
self._is_closed = True
'''
Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
'''
class RingBufferPublisher(trio.abc.SendChannel[PayloadT]):
'''
Use ChannelManager to create a multi ringbuf round robin sender that can
dynamically add or remove more outputs.
Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its
lifecycle.
'''
def __init__(
self,
n: trio.Nursery,
# amount of msgs to each ring before switching turns
msgs_per_turn: int = 1,
# global batch size for all channels
batch_size: int = 1,
encoder: Encoder | None = None
):
self._batch_size: int = batch_size
self.msgs_per_turn = msgs_per_turn
self._enc = encoder
# helper to manage acms + long running tasks
self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]](
n,
self._open_channel,
self._channel_task
)
# ensure no concurrent `.send()` calls
self._send_lock = trio.StrictFIFOLock()
# index of channel to be used for next send
self._next_turn: int = 0
# amount of messages sent this turn
self._turn_msgs: int = 0
# have we closed this publisher?
# set to `False` on `.__aenter__()`
self._is_closed: bool = True
@property
def closed(self) -> bool:
return self._is_closed
@property
def batch_size(self) -> int:
return self._batch_size
@batch_size.setter
def batch_size(self, value: int) -> None:
for info in self.channels:
info.channel.batch_size = value
@property
def channels(self) -> list[ChannelInfo]:
return self._chanmngr.channels
def _get_next_turn(self) -> int:
'''
Maybe switch turn and reset self._turn_msgs or just increment it.
Return current turn
'''
if self._turn_msgs == self.msgs_per_turn:
self._turn_msgs = 0
self._next_turn += 1
if self._next_turn >= len(self.channels):
self._next_turn = 0
else:
self._turn_msgs += 1
return self._next_turn
def get_channel(self, name: str) -> ChannelInfo:
'''
Get underlying ChannelInfo from name
'''
return self._chanmngr[name]
async def add_channel(
self,
token: RBToken,
):
await self._chanmngr.add_channel(token)
async def remove_channel(self, name: str):
await self._chanmngr.remove_channel(name)
@acm
async def _open_channel(
self,
token: RBToken
) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]:
async with attach_to_ringbuf_sender(
token,
batch_size=self._batch_size,
encoder=self._enc
) as ring:
yield ring
async def _channel_task(self, info: ChannelInfo) -> None:
'''
Wait forever until channel cancellation
'''
await trio.sleep_forever()
async def send(self, msg: bytes):
'''
If no output channels connected, wait until one, then fetch the next
channel based on turn.
Needs to acquire `self._send_lock` to ensure no concurrent calls.
'''
if self.closed:
raise trio.ClosedResourceError
if self._send_lock.locked():
raise trio.BusyResourceError
async with self._send_lock:
# wait at least one decoder connected
if len(self.channels) == 0:
await self._chanmngr.wait_for_channel()
turn = self._get_next_turn()
info = self.channels[turn]
await info.channel.send(msg)
async def broadcast(self, msg: PayloadT):
'''
Send a msg to all channels, if no channels connected, does nothing.
'''
if self.closed:
raise trio.ClosedResourceError
for info in self.channels:
await info.channel.send(msg)
async def flush(self, new_batch_size: int | None = None):
for info in self.channels:
try:
await info.channel.flush(new_batch_size=new_batch_size)
except trio.ClosedResourceError:
...
async def __aenter__(self):
self._is_closed = False
self._chanmngr.open()
return self
async def aclose(self) -> None:
if self.closed:
log.warning('tried to close RingBufferPublisher but its already closed...')
return
await self._chanmngr.close()
self._is_closed = True
class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]):
'''
Use ChannelManager to create a multi ringbuf receiver that can
dynamically add or remove more inputs and combine all into a single output.
In order for `self.receive` messages to be returned in order, publisher
will send all payloads as `OrderedPayload` msgpack encoded msgs, this
allows our channel handler tasks to just stash the out of order payloads
inside `self._pending_payloads` and if a in order payload is available
signal through `self._new_payload_event`.
On `self.receive` we wait until at least one channel is connected, then if
an in order payload is pending, we pop and return it, in case no in order
payload is available wait until next `self._new_payload_event.set()`.
'''
def __init__(
self,
n: trio.Nursery,
decoder: Decoder | None = None
):
self._dec = decoder
self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]](
n,
self._open_channel,
self._channel_task
)
self._schan, self._rchan = trio.open_memory_channel(0)
self._is_closed: bool = True
self._receive_lock = trio.StrictFIFOLock()
@property
def closed(self) -> bool:
return self._is_closed
@property
def channels(self) -> list[ChannelInfo]:
return self._chanmngr.channels
def get_channel(self, name: str):
return self._chanmngr[name]
async def add_channel(
self,
token: RBToken
):
await self._chanmngr.add_channel(token)
async def remove_channel(self, name: str):
await self._chanmngr.remove_channel(name)
@acm
async def _open_channel(
self,
token: RBToken
) -> AsyncContextManager[RingBufferSendChannel]:
async with attach_to_ringbuf_receiver(
token,
decoder=self._dec
) as ring:
yield ring
async def _channel_task(self, info: ChannelInfo) -> None:
'''
Iterate over receive channel messages, decode them as `OrderedPayload`s
and stash them in `self._pending_payloads`, in case we can pop next in
order payload, signal through setting `self._new_payload_event`.
'''
while True:
try:
msg = await info.channel.receive()
await self._schan.send(msg)
except tractor.linux.eventfd.EFDReadCancelled as e:
# when channel gets removed while we are doing a receive
log.exception(e)
break
except trio.EndOfChannel:
break
except trio.ClosedResourceError:
break
async def receive(self) -> PayloadT:
'''
Receive next in order msg
'''
if self.closed:
raise trio.ClosedResourceError
if self._receive_lock.locked():
raise trio.BusyResourceError
async with self._receive_lock:
return await self._rchan.receive()
async def __aenter__(self):
self._is_closed = False
self._chanmngr.open()
return self
async def aclose(self) -> None:
if self.closed:
return
await self._chanmngr.close()
await self._schan.aclose()
await self._rchan.aclose()
self._is_closed = True
'''
Actor module for managing publisher & subscriber channels remotely through
`tractor.context` rpc
'''
@dataclass
class PublisherEntry:
publisher: RingBufferPublisher | None = None
is_set: trio.Event = trio.Event()
_publishers: dict[str, PublisherEntry] = {}
def maybe_init_publisher(topic: str) -> PublisherEntry:
entry = _publishers.get(topic, None)
if not entry:
entry = PublisherEntry()
_publishers[topic] = entry
return entry
def set_publisher(topic: str, pub: RingBufferPublisher):
global _publishers
entry = _publishers.get(topic, None)
if not entry:
entry = maybe_init_publisher(topic)
if entry.publisher:
raise RuntimeError(
f'publisher for topic {topic} already set on {tractor.current_actor()}'
)
entry.publisher = pub
entry.is_set.set()
def get_publisher(topic: str = 'default') -> RingBufferPublisher:
entry = _publishers.get(topic, None)
if not entry or not entry.publisher:
raise RuntimeError(
f'{tractor.current_actor()} tried to get publisher'
'but it\'s not set'
)
return entry.publisher
async def wait_publisher(topic: str) -> RingBufferPublisher:
entry = maybe_init_publisher(topic)
await entry.is_set.wait()
return entry.publisher
@tractor.context
async def _add_pub_channel(
ctx: tractor.Context,
topic: str,
token: RBToken
):
publisher = await wait_publisher(topic)
await publisher.add_channel(token)
@tractor.context
async def _remove_pub_channel(
ctx: tractor.Context,
topic: str,
ring_name: str
):
publisher = await wait_publisher(topic)
maybe_token = fdshare.maybe_get_fds(ring_name)
if maybe_token:
await publisher.remove_channel(ring_name)
@acm
async def open_pub_channel_at(
actor_name: str,
token: RBToken,
topic: str = 'default',
):
async with tractor.find_actor(actor_name) as portal:
await portal.run(_add_pub_channel, topic=topic, token=token)
try:
yield
except trio.Cancelled:
log.warning(
'open_pub_channel_at got cancelled!\n'
f'\tactor_name = {actor_name}\n'
f'\ttoken = {token}\n'
)
raise
await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name)
@dataclass
class SubscriberEntry:
subscriber: RingBufferSubscriber | None = None
is_set: trio.Event = trio.Event()
_subscribers: dict[str, SubscriberEntry] = {}
def maybe_init_subscriber(topic: str) -> SubscriberEntry:
entry = _subscribers.get(topic, None)
if not entry:
entry = SubscriberEntry()
_subscribers[topic] = entry
return entry
def set_subscriber(topic: str, sub: RingBufferSubscriber):
global _subscribers
entry = _subscribers.get(topic, None)
if not entry:
entry = maybe_init_subscriber(topic)
if entry.subscriber:
raise RuntimeError(
f'subscriber for topic {topic} already set on {tractor.current_actor()}'
)
entry.subscriber = sub
entry.is_set.set()
def get_subscriber(topic: str = 'default') -> RingBufferSubscriber:
entry = _subscribers.get(topic, None)
if not entry or not entry.subscriber:
raise RuntimeError(
f'{tractor.current_actor()} tried to get subscriber'
'but it\'s not set'
)
return entry.subscriber
async def wait_subscriber(topic: str) -> RingBufferSubscriber:
entry = maybe_init_subscriber(topic)
await entry.is_set.wait()
return entry.subscriber
@tractor.context
async def _add_sub_channel(
ctx: tractor.Context,
topic: str,
token: RBToken
):
subscriber = await wait_subscriber(topic)
await subscriber.add_channel(token)
@tractor.context
async def _remove_sub_channel(
ctx: tractor.Context,
topic: str,
ring_name: str
):
subscriber = await wait_subscriber(topic)
maybe_token = fdshare.maybe_get_fds(ring_name)
if maybe_token:
await subscriber.remove_channel(ring_name)
@acm
async def open_sub_channel_at(
actor_name: str,
token: RBToken,
topic: str = 'default',
):
async with tractor.find_actor(actor_name) as portal:
await portal.run(_add_sub_channel, topic=topic, token=token)
try:
yield
except trio.Cancelled:
log.warning(
'open_sub_channel_at got cancelled!\n'
f'\tactor_name = {actor_name}\n'
f'\ttoken = {token}\n'
)
raise
await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name)
'''
High level helpers to open publisher & subscriber
'''
@acm
async def open_ringbuf_publisher(
# name to distinguish this publisher
topic: str = 'default',
# global batch size for channels
batch_size: int = 1,
# messages before changing output channel
msgs_per_turn: int = 1,
encoder: Encoder | None = None,
# ensure subscriber receives in same order publisher sent
# causes it to use wrapped payloads which contain the og
# index
guarantee_order: bool = False,
# on creation, set the `_publisher` global in order to use the provided
# tractor.context & helper utils for adding and removing new channels from
# remote actors
set_module_var: bool = True
) -> AsyncContextManager[RingBufferPublisher]:
'''
Open a new ringbuf publisher
'''
async with (
trio.open_nursery(strict_exception_groups=False) as n,
RingBufferPublisher(
n,
batch_size=batch_size,
encoder=encoder,
) as publisher
):
if guarantee_order:
order_send_channel(publisher)
if set_module_var:
set_publisher(topic, publisher)
yield publisher
n.cancel_scope.cancel()
@acm
async def open_ringbuf_subscriber(
# name to distinguish this subscriber
topic: str = 'default',
decoder: Decoder | None = None,
# expect indexed payloads and unwrap them in order
guarantee_order: bool = False,
# on creation, set the `_subscriber` global in order to use the provided
# tractor.context & helper utils for adding and removing new channels from
# remote actors
set_module_var: bool = True
) -> AsyncContextManager[RingBufferPublisher]:
'''
Open a new ringbuf subscriber
'''
async with (
trio.open_nursery(strict_exception_groups=False) as n,
RingBufferSubscriber(n, decoder=decoder) as subscriber
):
# maybe monkey patch `.receive` to use indexed payloads
if guarantee_order:
order_receive_channel(subscriber)
# maybe set global module var for remote actor channel updates
if set_module_var:
set_subscriber(topic, subscriber)
yield subscriber
n.cancel_scope.cancel()

View File

@ -78,7 +78,7 @@ class MsgTransport(Protocol):
# eventual msg definition/types?
# - https://docs.python.org/3/library/typing.html#typing.Protocol
stream: trio.SocketStream
stream: trio.abc.Stream
drained: list[MsgType]
address_type: ClassVar[Type[Address]]

View File

@ -0,0 +1,15 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -0,0 +1,316 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and trio.
cpython impl:
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
'''
import os
import array
import tempfile
from uuid import uuid4
from pathlib import Path
from typing import AsyncContextManager
from contextlib import asynccontextmanager as acm
import trio
import tractor
from trio import socket
log = tractor.log.get_logger(__name__)
class FDSharingError(Exception):
...
@acm
async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
'''
Async trio reimplementation of `multiprocessing.reduction.sendfds`
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142
It's implemented using an async context manager in order to simplyfy usage
with `tractor.context`s, we can open a context in a remote actor that uses
this acm inside of it, and uses `ctx.started()` to signal the original
caller actor to perform the `recv_fds` call.
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
'''
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
await sock.bind(sock_path)
sock.listen(1)
yield # socket is setup, ready for receiver connect
# wait until receiver connects
conn, _ = await sock.accept()
# setup int array for fds
fds = array.array('i', fds)
# first byte of msg will be len of fds to send % 256, acting as a fd amount
# verification on `recv_fds` we refer to it as `check_byte`
msg = bytes([len(fds) % 256])
# send msg with custom SCM_RIGHTS type
await conn.sendmsg(
[msg],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
)
# finally wait receiver ack
if await conn.recv(1) != b'A':
raise FDSharingError('did not receive acknowledgement of fd')
conn.close()
sock.close()
os.unlink(sock_path)
async def recv_fds(sock_path: str, amount: int) -> tuple:
'''
Async trio reimplementation of `multiprocessing.reduction.recvfds`
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150
It's equivalent to std just using `trio.open_unix_socket` for connecting and
changes on error handling.
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
'''
stream = await trio.open_unix_socket(sock_path)
sock = stream.socket
# prepare int array for fds
a = array.array('i')
bytes_size = a.itemsize * amount
# receive 1 byte + space necesary for SCM_RIGHTS msg for {amount} fds
msg, ancdata, flags, addr = await sock.recvmsg(
1, socket.CMSG_SPACE(bytes_size)
)
# maybe failed to receive msg?
if not msg and not ancdata:
raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF')
# send ack, std comment mentions this ack pattern was to get around an
# old macosx bug, but they are not sure if its necesary any more, in
# any case its not a bad pattern to keep
await sock.send(b'A') # Ack
# expect to receive only one `ancdata` item
if len(ancdata) != 1:
raise FDSharingError(
f'Expected to receive exactly one \"ancdata\" but got {len(ancdata)}: {ancdata}'
)
# unpack SCM_RIGHTS msg
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
# check proper msg type
if cmsg_level != socket.SOL_SOCKET:
raise FDSharingError(
f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}'
)
if cmsg_type != socket.SCM_RIGHTS:
raise FDSharingError(
f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}'
)
# check proper data alignment
length = len(cmsg_data)
if length % a.itemsize != 0:
raise FDSharingError(
f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}'
)
# attempt to cast as int array
a.frombytes(cmsg_data)
# validate length check byte
valid_check_byte = amount % 256 # check byte acording to `recv_fds` caller
recvd_check_byte = msg[0] # actual received check byte
payload_check_byte = len(a) % 256 # check byte acording to received fd int array
if recvd_check_byte != payload_check_byte:
raise FDSharingError(
'Validation failed: received check byte '
f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})'
)
if valid_check_byte != recvd_check_byte:
raise FDSharingError(
'Validation failed: received check byte '
f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})'
)
return tuple(a)
'''
Share FD actor module
Add "tractor.linux._fdshare" to enabled modules on actors to allow sharing of
FDs with other actors.
Use `share_fds` function to register a set of fds with a name, then other
actors can use `request_fds_from` function to retrieve the fds.
Use `unshare_fds` to disable sharing of a set of FDs.
'''
FDType = tuple[int]
_fds: dict[str, FDType] = {}
def maybe_get_fds(name: str) -> FDType | None:
'''
Get registered FDs with a given name or return None
'''
return _fds.get(name, None)
def get_fds(name: str) -> FDType:
'''
Get registered FDs with a given name or raise
'''
fds = maybe_get_fds(name)
if not fds:
raise RuntimeError(f'No FDs with name {name} found!')
return fds
def share_fds(
name: str,
fds: tuple[int],
) -> None:
'''
Register a set of fds to be shared under a given name.
'''
this_actor = tractor.current_actor()
if __name__ not in this_actor.enable_modules:
raise RuntimeError(
f'Tried to share FDs {fds} with name {name}, but '
f'module {__name__} is not enabled in actor {this_actor.name}!'
)
maybe_fds = maybe_get_fds(name)
if maybe_fds:
raise RuntimeError(f'share FDs: {maybe_fds} already tied to name {name}')
_fds[name] = fds
def unshare_fds(name: str) -> None:
'''
Unregister a set of fds to disable sharing them.
'''
get_fds(name) # raise if not exists
del _fds[name]
@tractor.context
async def _pass_fds(
ctx: tractor.Context,
name: str,
sock_path: str
) -> None:
'''
Endpoint to request a set of FDs from current actor, will use `ctx.started`
to send original FDs, then `send_fds` will block until remote side finishes
the `recv_fds` call.
'''
# get fds or raise error
fds = get_fds(name)
# start fd passing context using socket on `sock_path`
async with send_fds(fds, sock_path):
# send original fds through ctx.started
await ctx.started(fds)
async def request_fds_from(
actor_name: str,
fds_name: str
) -> FDType:
'''
Use this function to retreive shared FDs from `actor_name`.
'''
this_actor = tractor.current_actor()
# create a temporary path for the UDS sock
sock_path = str(
Path(tempfile.gettempdir())
/
f'{fds_name}-from-{actor_name}-to-{this_actor.name}.sock'
)
# having a socket path length > 100 aprox can cause:
# OSError: AF_UNIX path too long
# https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/sys_un.h.html#tag_13_67_04
# attempt sock path creation with smaller names
if len(sock_path) > 100:
sock_path = str(
Path(tempfile.gettempdir())
/
f'{fds_name}-to-{this_actor.name}.sock'
)
if len(sock_path) > 100:
# just use uuid4
sock_path = str(
Path(tempfile.gettempdir())
/
f'pass-fds-{uuid4()}.sock'
)
async with (
tractor.find_actor(actor_name) as portal,
portal.open_context(
_pass_fds,
name=fds_name,
sock_path=sock_path
) as (ctx, fds_info),
):
# get original FDs
og_fds = fds_info
# retrieve copies of FDs
fds = await recv_fds(sock_path, len(og_fds))
log.info(
f'{this_actor.name} received fds: {og_fds} -> {fds}'
)
return fds

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Linux specifics, for now we are only exposing EventFD
Expose libc eventfd APIs
'''
import os
@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
raise OSError(errno.errorcode[ffi.errno], 'close failed')
class EFDReadCancelled(Exception):
...
class EventFD:
'''
Use a previously opened eventfd(2), meant to be used in
@ -124,26 +128,82 @@ class EventFD:
self._fd: int = fd
self._omode: str = omode
self._fobj = None
self._cscope: trio.CancelScope | None = None
self._is_closed: bool = True
self._read_lock = trio.StrictFIFOLock()
@property
def closed(self) -> bool:
return self._is_closed
@property
def fd(self) -> int | None:
return self._fd
def write(self, value: int) -> int:
if self.closed:
raise trio.ClosedResourceError
return write_eventfd(self._fd, value)
async def read(self) -> int:
return await trio.to_thread.run_sync(
read_eventfd, self._fd,
abandon_on_cancel=True
)
'''
Async wrapper for `read_eventfd(self.fd)`
`trio.to_thread.run_sync` is used, need to use a `trio.CancelScope`
in order to make it cancellable when `self.close()` is called.
'''
if self.closed:
raise trio.ClosedResourceError
if self._read_lock.locked():
raise trio.BusyResourceError
async with self._read_lock:
self._cscope = trio.CancelScope()
with self._cscope:
try:
return await trio.to_thread.run_sync(
read_eventfd, self._fd,
abandon_on_cancel=True
)
except OSError as e:
if e.errno != errno.EBADF:
raise
raise trio.BrokenResourceError
if self._cscope.cancelled_caught:
raise EFDReadCancelled
self._cscope = None
def read_nowait(self) -> int:
'''
Direct call to `read_eventfd(self.fd)`, unless `eventfd` was
opened with `EFD_NONBLOCK` its gonna block the thread.
'''
return read_eventfd(self._fd)
def open(self):
self._fobj = os.fdopen(self._fd, self._omode)
self._is_closed = False
def close(self):
if self._fobj:
self._fobj.close()
try:
self._fobj.close()
except OSError:
...
if self._cscope:
self._cscope.cancel()
self._is_closed = True
def __enter__(self):
self.open()

View File

@ -32,3 +32,8 @@ from ._broadcast import (
from ._beg import (
collapse_eg as collapse_eg,
)
from ._ordering import (
order_send_channel as order_send_channel,
order_receive_channel as order_receive_channel
)

View File

@ -70,7 +70,8 @@ async def maybe_open_nursery(
yield nursery
else:
async with lib.open_nursery(**kwargs) as nursery:
nursery.cancel_scope.shield = shield
if lib == trio:
nursery.cancel_scope.shield = shield
yield nursery

View File

@ -0,0 +1,108 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Helpers to guarantee ordering of messages through a unordered channel
'''
from __future__ import annotations
from heapq import (
heappush,
heappop
)
import trio
import msgspec
class OrderedPayload(msgspec.Struct, frozen=True):
index: int
payload: bytes
@classmethod
def from_msg(cls, msg: bytes) -> OrderedPayload:
return msgspec.msgpack.decode(msg, type=OrderedPayload)
def encode(self) -> bytes:
return msgspec.msgpack.encode(self)
def order_send_channel(
channel: trio.abc.SendChannel[bytes],
start_index: int = 0
):
next_index = start_index
send_lock = trio.StrictFIFOLock()
channel._send = channel.send
channel._aclose = channel.aclose
async def send(msg: bytes):
nonlocal next_index
async with send_lock:
await channel._send(
OrderedPayload(
index=next_index,
payload=msg
).encode()
)
next_index += 1
async def aclose():
async with send_lock:
await channel._aclose()
channel.send = send
channel.aclose = aclose
def order_receive_channel(
channel: trio.abc.ReceiveChannel[bytes],
start_index: int = 0
):
next_index = start_index
pqueue = []
channel._receive = channel.receive
def can_pop_next() -> bool:
return (
len(pqueue) > 0
and
pqueue[0][0] == next_index
)
async def drain_to_heap():
while not can_pop_next():
msg = await channel._receive()
msg = OrderedPayload.from_msg(msg)
heappush(pqueue, (msg.index, msg.payload))
def pop_next():
nonlocal next_index
_, msg = heappop(pqueue)
next_index += 1
return msg
async def receive() -> bytes:
if can_pop_next():
return pop_next()
await drain_to_heap()
return pop_next()
channel.receive = receive