Linux specific IPC RingBuff using EventFD for async reader wakeup #10
|
@ -0,0 +1,66 @@
|
||||||
|
import trio
|
||||||
|
import pytest
|
||||||
|
from tractor.linux.eventfd import (
|
||||||
|
open_eventfd,
|
||||||
|
EFDReadCancelled,
|
||||||
|
EventFD
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_cancellation():
|
||||||
|
'''
|
||||||
|
Ensure EventFD.read raises EFDReadCancelled if EventFD.close()
|
||||||
|
is called.
|
||||||
|
|
||||||
|
'''
|
||||||
|
fd = open_eventfd()
|
||||||
|
|
||||||
|
async def bg_read(event: EventFD):
|
||||||
|
with pytest.raises(EFDReadCancelled):
|
||||||
|
await event.read()
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
with (
|
||||||
|
EventFD(fd, 'w') as event,
|
||||||
|
trio.fail_after(3)
|
||||||
|
):
|
||||||
|
n.start_soon(bg_read, event)
|
||||||
|
await trio.sleep(0.2)
|
||||||
|
event.close()
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_trio_semantics():
|
||||||
|
'''
|
||||||
|
Ensure EventFD.read raises trio.ClosedResourceError and
|
||||||
|
trio.BusyResourceError.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
fd = open_eventfd()
|
||||||
|
|
||||||
|
async def bg_read(event: EventFD):
|
||||||
|
try:
|
||||||
|
await event.read()
|
||||||
|
|
||||||
|
except EFDReadCancelled:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
|
||||||
|
# start background read and attempt
|
||||||
|
# foreground read, should be busy
|
||||||
|
with EventFD(fd, 'w') as event:
|
||||||
|
n.start_soon(bg_read, event)
|
||||||
|
await trio.sleep(0.2)
|
||||||
|
with pytest.raises(trio.BusyResourceError):
|
||||||
|
await event.read()
|
||||||
|
|
||||||
|
# attempt read after close
|
||||||
|
with pytest.raises(trio.ClosedResourceError):
|
||||||
|
await event.read()
|
||||||
|
|
||||||
|
trio.run(main)
|
|
@ -0,0 +1,185 @@
|
||||||
|
from typing import AsyncContextManager
|
||||||
|
from contextlib import asynccontextmanager as acm
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import pytest
|
||||||
|
import tractor
|
||||||
|
|
||||||
|
from tractor.trionics import gather_contexts
|
||||||
|
|
||||||
|
from tractor.ipc._ringbuf import open_ringbufs
|
||||||
|
from tractor.ipc._ringbuf._pubsub import (
|
||||||
|
open_ringbuf_publisher,
|
||||||
|
open_ringbuf_subscriber,
|
||||||
|
get_publisher,
|
||||||
|
get_subscriber,
|
||||||
|
open_pub_channel_at,
|
||||||
|
open_sub_channel_at
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
log = tractor.log.get_console_log(level='info')
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def publish_range(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
size: int
|
||||||
|
):
|
||||||
|
pub = get_publisher()
|
||||||
|
await ctx.started()
|
||||||
|
for i in range(size):
|
||||||
|
await pub.send(i.to_bytes(4))
|
||||||
|
log.info(f'sent {i}')
|
||||||
|
|
||||||
|
await pub.flush()
|
||||||
|
|
||||||
|
log.info('range done')
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def subscribe_range(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
size: int
|
||||||
|
):
|
||||||
|
sub = get_subscriber()
|
||||||
|
await ctx.started()
|
||||||
|
|
||||||
|
for i in range(size):
|
||||||
|
recv = int.from_bytes(await sub.receive())
|
||||||
|
if recv != i:
|
||||||
|
raise AssertionError(
|
||||||
|
f'received: {recv} expected: {i}'
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info(f'received: {recv}')
|
||||||
|
|
||||||
|
log.info('range done')
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def subscriber_child(ctx: tractor.Context):
|
||||||
|
try:
|
||||||
|
async with open_ringbuf_subscriber(guarantee_order=True):
|
||||||
|
await ctx.started()
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
log.info('subscriber exit')
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def publisher_child(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
batch_size: int
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
async with open_ringbuf_publisher(
|
||||||
|
guarantee_order=True,
|
||||||
|
batch_size=batch_size
|
||||||
|
):
|
||||||
|
await ctx.started()
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
log.info('publisher exit')
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open_pubsub_test_actors(
|
||||||
|
|
||||||
|
ring_names: list[str],
|
||||||
|
size: int,
|
||||||
|
batch_size: int
|
||||||
|
|
||||||
|
) -> AsyncContextManager[tuple[tractor.Portal, tractor.Portal]]:
|
||||||
|
|
||||||
|
with trio.fail_after(5):
|
||||||
|
async with tractor.open_nursery(
|
||||||
|
enable_modules=[
|
||||||
|
'tractor.linux._fdshare'
|
||||||
|
]
|
||||||
|
) as an:
|
||||||
|
modules = [
|
||||||
|
__name__,
|
||||||
|
'tractor.linux._fdshare',
|
||||||
|
'tractor.ipc._ringbuf._pubsub'
|
||||||
|
]
|
||||||
|
sub_portal = await an.start_actor(
|
||||||
|
'sub',
|
||||||
|
enable_modules=modules
|
||||||
|
)
|
||||||
|
pub_portal = await an.start_actor(
|
||||||
|
'pub',
|
||||||
|
enable_modules=modules
|
||||||
|
)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
sub_portal.open_context(subscriber_child) as (long_rctx, _),
|
||||||
|
pub_portal.open_context(
|
||||||
|
publisher_child,
|
||||||
|
batch_size=batch_size
|
||||||
|
) as (long_sctx, _),
|
||||||
|
|
||||||
|
open_ringbufs(ring_names) as tokens,
|
||||||
|
|
||||||
|
gather_contexts([
|
||||||
|
open_sub_channel_at('sub', ring)
|
||||||
|
for ring in tokens
|
||||||
|
]),
|
||||||
|
gather_contexts([
|
||||||
|
open_pub_channel_at('pub', ring)
|
||||||
|
for ring in tokens
|
||||||
|
]),
|
||||||
|
sub_portal.open_context(subscribe_range, size=size) as (rctx, _),
|
||||||
|
pub_portal.open_context(publish_range, size=size) as (sctx, _)
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
await rctx.wait_for_result()
|
||||||
|
await sctx.wait_for_result()
|
||||||
|
|
||||||
|
await long_sctx.cancel()
|
||||||
|
await long_rctx.cancel()
|
||||||
|
|
||||||
|
await an.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
('ring_names', 'size', 'batch_size'),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
['ring-first'],
|
||||||
|
100,
|
||||||
|
1
|
||||||
|
),
|
||||||
|
(
|
||||||
|
['ring-first'],
|
||||||
|
69,
|
||||||
|
1
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[f'multi-ring-{i}' for i in range(3)],
|
||||||
|
1000,
|
||||||
|
100
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
'simple',
|
||||||
|
'redo-simple',
|
||||||
|
'multi-ring',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_pubsub(
|
||||||
|
request,
|
||||||
|
ring_names: list[str],
|
||||||
|
size: int,
|
||||||
|
batch_size: int
|
||||||
|
):
|
||||||
|
async def main():
|
||||||
|
async with open_pubsub_test_actors(
|
||||||
|
ring_names, size, batch_size
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
trio.run(main)
|
|
@ -1,4 +1,5 @@
|
||||||
import time
|
import time
|
||||||
|
import hashlib
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -6,36 +7,45 @@ import pytest
|
||||||
import tractor
|
import tractor
|
||||||
from tractor.ipc._ringbuf import (
|
from tractor.ipc._ringbuf import (
|
||||||
open_ringbuf,
|
open_ringbuf,
|
||||||
|
open_ringbuf_pair,
|
||||||
|
attach_to_ringbuf_receiver,
|
||||||
|
attach_to_ringbuf_sender,
|
||||||
|
attach_to_ringbuf_channel,
|
||||||
RBToken,
|
RBToken,
|
||||||
RingBuffSender,
|
|
||||||
RingBuffReceiver
|
|
||||||
)
|
)
|
||||||
from tractor._testing.samples import (
|
from tractor._testing.samples import (
|
||||||
generate_sample_messages,
|
generate_single_byte_msgs,
|
||||||
|
RandomBytesGenerator
|
||||||
)
|
)
|
||||||
|
|
||||||
# in case you don't want to melt your cores, uncomment dis!
|
|
||||||
pytestmark = pytest.mark.skip
|
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def child_read_shm(
|
async def child_read_shm(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
msg_amount: int,
|
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
total_bytes: int,
|
) -> str:
|
||||||
) -> None:
|
'''
|
||||||
recvd_bytes = 0
|
Sub-actor used in `test_ringbuf`.
|
||||||
await ctx.started()
|
|
||||||
start_ts = time.time()
|
|
||||||
async with RingBuffReceiver(token) as receiver:
|
|
||||||
while recvd_bytes < total_bytes:
|
|
||||||
msg = await receiver.receive_some()
|
|
||||||
recvd_bytes += len(msg)
|
|
||||||
|
|
||||||
# make sure we dont hold any memoryviews
|
Attach to a ringbuf and receive all messages until end of stream.
|
||||||
# before the ctx manager aclose()
|
Keep track of how many bytes received and also calculate
|
||||||
msg = None
|
sha256 of the whole byte stream.
|
||||||
|
|
||||||
|
Calculate and print performance stats, finally return calculated
|
||||||
|
hash.
|
||||||
|
|
||||||
|
'''
|
||||||
|
await ctx.started()
|
||||||
|
print('reader started')
|
||||||
|
msg_amount = 0
|
||||||
|
recvd_bytes = 0
|
||||||
|
recvd_hash = hashlib.sha256()
|
||||||
|
start_ts = time.time()
|
||||||
|
async with attach_to_ringbuf_receiver(token) as receiver:
|
||||||
|
async for msg in receiver:
|
||||||
|
msg_amount += 1
|
||||||
|
recvd_hash.update(msg)
|
||||||
|
recvd_bytes += len(msg)
|
||||||
|
|
||||||
end_ts = time.time()
|
end_ts = time.time()
|
||||||
elapsed = end_ts - start_ts
|
elapsed = end_ts - start_ts
|
||||||
|
@ -44,6 +54,10 @@ async def child_read_shm(
|
||||||
print(f'\n\telapsed ms: {elapsed_ms}')
|
print(f'\n\telapsed ms: {elapsed_ms}')
|
||||||
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
||||||
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
||||||
|
print(f'\treceived msgs: {msg_amount:,}')
|
||||||
|
print(f'\treceived bytes: {recvd_bytes:,}')
|
||||||
|
|
||||||
|
return recvd_hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
|
@ -52,17 +66,37 @@ async def child_write_shm(
|
||||||
msg_amount: int,
|
msg_amount: int,
|
||||||
rand_min: int,
|
rand_min: int,
|
||||||
rand_max: int,
|
rand_max: int,
|
||||||
token: RBToken,
|
buf_size: int
|
||||||
) -> None:
|
) -> None:
|
||||||
msgs, total_bytes = generate_sample_messages(
|
'''
|
||||||
|
Sub-actor used in `test_ringbuf`
|
||||||
|
|
||||||
|
Generate `msg_amount` payloads with
|
||||||
|
`random.randint(rand_min, rand_max)` random bytes at the end,
|
||||||
|
Calculate sha256 hash and send it to parent on `ctx.started`.
|
||||||
|
|
||||||
|
Attach to ringbuf and send all generated messages.
|
||||||
|
|
||||||
|
'''
|
||||||
|
rng = RandomBytesGenerator(
|
||||||
msg_amount,
|
msg_amount,
|
||||||
rand_min=rand_min,
|
rand_min=rand_min,
|
||||||
rand_max=rand_max,
|
rand_max=rand_max,
|
||||||
)
|
)
|
||||||
await ctx.started(total_bytes)
|
async with (
|
||||||
async with RingBuffSender(token) as sender:
|
open_ringbuf('test_ringbuf', buf_size=buf_size) as token,
|
||||||
for msg in msgs:
|
attach_to_ringbuf_sender(token) as sender
|
||||||
await sender.send_all(msg)
|
):
|
||||||
|
await ctx.started(token)
|
||||||
|
print('writer started')
|
||||||
|
for msg in rng:
|
||||||
|
await sender.send(msg)
|
||||||
|
|
||||||
|
if rng.msgs_generated % rng.recommended_log_interval == 0:
|
||||||
|
print(f'wrote {rng.msgs_generated} msgs')
|
||||||
|
|
||||||
|
print('writer exit')
|
||||||
|
return rng.hexdigest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -89,84 +123,91 @@ def test_ringbuf(
|
||||||
rand_max: int,
|
rand_max: int,
|
||||||
buf_size: int
|
buf_size: int
|
||||||
):
|
):
|
||||||
|
'''
|
||||||
|
- Open a new ring buf on root actor
|
||||||
|
- Open `child_write_shm` ctx in sub-actor which will generate a
|
||||||
|
random payload and send its hash on `ctx.started`, finally sending
|
||||||
|
the payload through the stream.
|
||||||
|
- Open `child_read_shm` ctx in sub-actor which will receive the
|
||||||
|
payload, calculate perf stats and return the hash.
|
||||||
|
- Compare both hashes
|
||||||
|
|
||||||
|
'''
|
||||||
async def main():
|
async def main():
|
||||||
with open_ringbuf(
|
async with tractor.open_nursery() as an:
|
||||||
'test_ringbuf',
|
send_p = await an.start_actor(
|
||||||
buf_size=buf_size
|
'ring_sender',
|
||||||
) as token:
|
enable_modules=[
|
||||||
proc_kwargs = {
|
__name__,
|
||||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
'tractor.linux._fdshare'
|
||||||
}
|
],
|
||||||
|
)
|
||||||
|
recv_p = await an.start_actor(
|
||||||
|
'ring_receiver',
|
||||||
|
enable_modules=[
|
||||||
|
__name__,
|
||||||
|
'tractor.linux._fdshare'
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async with (
|
||||||
|
send_p.open_context(
|
||||||
|
child_write_shm,
|
||||||
|
msg_amount=msg_amount,
|
||||||
|
rand_min=rand_min,
|
||||||
|
rand_max=rand_max,
|
||||||
|
buf_size=buf_size
|
||||||
|
) as (sctx, token),
|
||||||
|
|
||||||
common_kwargs = {
|
recv_p.open_context(
|
||||||
'msg_amount': msg_amount,
|
child_read_shm,
|
||||||
'token': token,
|
token=token,
|
||||||
}
|
) as (rctx, _),
|
||||||
async with tractor.open_nursery() as an:
|
):
|
||||||
send_p = await an.start_actor(
|
sent_hash = await sctx.result()
|
||||||
'ring_sender',
|
recvd_hash = await rctx.result()
|
||||||
enable_modules=[__name__],
|
|
||||||
proc_kwargs=proc_kwargs
|
|
||||||
)
|
|
||||||
recv_p = await an.start_actor(
|
|
||||||
'ring_receiver',
|
|
||||||
enable_modules=[__name__],
|
|
||||||
proc_kwargs=proc_kwargs
|
|
||||||
)
|
|
||||||
async with (
|
|
||||||
send_p.open_context(
|
|
||||||
child_write_shm,
|
|
||||||
rand_min=rand_min,
|
|
||||||
rand_max=rand_max,
|
|
||||||
**common_kwargs
|
|
||||||
) as (sctx, total_bytes),
|
|
||||||
recv_p.open_context(
|
|
||||||
child_read_shm,
|
|
||||||
**common_kwargs,
|
|
||||||
total_bytes=total_bytes,
|
|
||||||
) as (sctx, _sent),
|
|
||||||
):
|
|
||||||
await recv_p.result()
|
|
||||||
|
|
||||||
await send_p.cancel_actor()
|
assert sent_hash == recvd_hash
|
||||||
await recv_p.cancel_actor()
|
|
||||||
|
|
||||||
|
await an.cancel()
|
||||||
|
|
||||||
trio.run(main)
|
trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def child_blocked_receiver(
|
async def child_blocked_receiver(ctx: tractor.Context):
|
||||||
ctx: tractor.Context,
|
async with (
|
||||||
token: RBToken
|
open_ringbuf('test_ring_cancel_reader') as token,
|
||||||
):
|
|
||||||
async with RingBuffReceiver(token) as receiver:
|
attach_to_ringbuf_receiver(token) as receiver
|
||||||
await ctx.started()
|
):
|
||||||
|
await ctx.started(token)
|
||||||
await receiver.receive_some()
|
await receiver.receive_some()
|
||||||
|
|
||||||
|
|
||||||
def test_ring_reader_cancel():
|
def test_reader_cancel():
|
||||||
|
'''
|
||||||
|
Test that a receiver blocked on eventfd(2) read responds to
|
||||||
|
cancellation.
|
||||||
|
|
||||||
|
'''
|
||||||
async def main():
|
async def main():
|
||||||
with open_ringbuf('test_ring_cancel_reader') as token:
|
async with tractor.open_nursery() as an:
|
||||||
|
recv_p = await an.start_actor(
|
||||||
|
'ring_blocked_receiver',
|
||||||
|
enable_modules=[
|
||||||
|
__name__,
|
||||||
|
'tractor.linux._fdshare'
|
||||||
|
],
|
||||||
|
)
|
||||||
async with (
|
async with (
|
||||||
tractor.open_nursery() as an,
|
recv_p.open_context(
|
||||||
RingBuffSender(token) as _sender,
|
child_blocked_receiver,
|
||||||
|
) as (sctx, token),
|
||||||
|
|
||||||
|
attach_to_ringbuf_sender(token),
|
||||||
):
|
):
|
||||||
recv_p = await an.start_actor(
|
await trio.sleep(.1)
|
||||||
'ring_blocked_receiver',
|
await an.cancel()
|
||||||
enable_modules=[__name__],
|
|
||||||
proc_kwargs={
|
|
||||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
async with (
|
|
||||||
recv_p.open_context(
|
|
||||||
child_blocked_receiver,
|
|
||||||
token=token
|
|
||||||
) as (sctx, _sent),
|
|
||||||
):
|
|
||||||
await trio.sleep(1)
|
|
||||||
await an.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
with pytest.raises(tractor._exceptions.ContextCancelled):
|
with pytest.raises(tractor._exceptions.ContextCancelled):
|
||||||
|
@ -174,38 +215,166 @@ def test_ring_reader_cancel():
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def child_blocked_sender(
|
async def child_blocked_sender(ctx: tractor.Context):
|
||||||
ctx: tractor.Context,
|
async with (
|
||||||
token: RBToken
|
open_ringbuf(
|
||||||
):
|
'test_ring_cancel_sender',
|
||||||
async with RingBuffSender(token) as sender:
|
buf_size=1
|
||||||
await ctx.started()
|
) as token,
|
||||||
|
|
||||||
|
attach_to_ringbuf_sender(token) as sender
|
||||||
|
):
|
||||||
|
await ctx.started(token)
|
||||||
await sender.send_all(b'this will wrap')
|
await sender.send_all(b'this will wrap')
|
||||||
|
|
||||||
|
|
||||||
def test_ring_sender_cancel():
|
def test_sender_cancel():
|
||||||
|
'''
|
||||||
|
Test that a sender blocked on eventfd(2) read responds to
|
||||||
|
cancellation.
|
||||||
|
|
||||||
|
'''
|
||||||
async def main():
|
async def main():
|
||||||
with open_ringbuf(
|
async with tractor.open_nursery() as an:
|
||||||
'test_ring_cancel_sender',
|
recv_p = await an.start_actor(
|
||||||
buf_size=1
|
'ring_blocked_sender',
|
||||||
) as token:
|
enable_modules=[
|
||||||
async with tractor.open_nursery() as an:
|
__name__,
|
||||||
recv_p = await an.start_actor(
|
'tractor.linux._fdshare'
|
||||||
'ring_blocked_sender',
|
],
|
||||||
enable_modules=[__name__],
|
)
|
||||||
proc_kwargs={
|
async with (
|
||||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
recv_p.open_context(
|
||||||
}
|
child_blocked_sender,
|
||||||
)
|
) as (sctx, token),
|
||||||
async with (
|
|
||||||
recv_p.open_context(
|
attach_to_ringbuf_receiver(token)
|
||||||
child_blocked_sender,
|
):
|
||||||
token=token
|
await trio.sleep(.1)
|
||||||
) as (sctx, _sent),
|
await an.cancel()
|
||||||
):
|
|
||||||
await trio.sleep(1)
|
|
||||||
await an.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
with pytest.raises(tractor._exceptions.ContextCancelled):
|
with pytest.raises(tractor._exceptions.ContextCancelled):
|
||||||
trio.run(main)
|
trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
def test_receiver_max_bytes():
|
||||||
|
'''
|
||||||
|
Test that RingBuffReceiver.receive_some's max_bytes optional
|
||||||
|
argument works correctly, send a msg of size 100, then
|
||||||
|
force receive of messages with max_bytes == 1, wait until
|
||||||
|
100 of these messages are received, then compare join of
|
||||||
|
msgs with original message
|
||||||
|
|
||||||
|
'''
|
||||||
|
msg = generate_single_byte_msgs(100)
|
||||||
|
msgs = []
|
||||||
|
|
||||||
|
rb_common = {
|
||||||
|
'cleanup': False,
|
||||||
|
'is_ipc': False
|
||||||
|
}
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with (
|
||||||
|
open_ringbuf(
|
||||||
|
'test_ringbuf_max_bytes',
|
||||||
|
buf_size=10,
|
||||||
|
is_ipc=False
|
||||||
|
) as token,
|
||||||
|
|
||||||
|
trio.open_nursery() as n,
|
||||||
|
|
||||||
|
attach_to_ringbuf_sender(token, **rb_common) as sender,
|
||||||
|
|
||||||
|
attach_to_ringbuf_receiver(token, **rb_common) as receiver
|
||||||
|
):
|
||||||
|
async def _send_and_close():
|
||||||
|
await sender.send_all(msg)
|
||||||
|
await sender.aclose()
|
||||||
|
|
||||||
|
n.start_soon(_send_and_close)
|
||||||
|
while len(msgs) < len(msg):
|
||||||
|
msg_part = await receiver.receive_some(max_bytes=1)
|
||||||
|
assert len(msg_part) == 1
|
||||||
|
msgs.append(msg_part)
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
assert msg == b''.join(msgs)
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def child_channel_sender(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
msg_amount_min: int,
|
||||||
|
msg_amount_max: int,
|
||||||
|
token_in: RBToken,
|
||||||
|
token_out: RBToken
|
||||||
|
):
|
||||||
|
import random
|
||||||
|
rng = RandomBytesGenerator(
|
||||||
|
random.randint(msg_amount_min, msg_amount_max),
|
||||||
|
rand_min=256,
|
||||||
|
rand_max=1024,
|
||||||
|
)
|
||||||
|
async with attach_to_ringbuf_channel(
|
||||||
|
token_in,
|
||||||
|
token_out
|
||||||
|
) as chan:
|
||||||
|
await ctx.started()
|
||||||
|
for msg in rng:
|
||||||
|
await chan.send(msg)
|
||||||
|
|
||||||
|
await chan.send(b'bye')
|
||||||
|
await chan.receive()
|
||||||
|
return rng.hexdigest
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel():
|
||||||
|
|
||||||
|
msg_amount_min = 100
|
||||||
|
msg_amount_max = 1000
|
||||||
|
|
||||||
|
mods = [
|
||||||
|
__name__,
|
||||||
|
'tractor.linux._fdshare'
|
||||||
|
]
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with (
|
||||||
|
tractor.open_nursery(enable_modules=mods) as an,
|
||||||
|
|
||||||
|
open_ringbuf_pair(
|
||||||
|
'test_ringbuf_transport'
|
||||||
|
) as (send_token, recv_token),
|
||||||
|
|
||||||
|
attach_to_ringbuf_channel(send_token, recv_token) as chan,
|
||||||
|
):
|
||||||
|
sender = await an.start_actor(
|
||||||
|
'test_ringbuf_transport_sender',
|
||||||
|
enable_modules=mods,
|
||||||
|
)
|
||||||
|
async with (
|
||||||
|
sender.open_context(
|
||||||
|
child_channel_sender,
|
||||||
|
msg_amount_min=msg_amount_min,
|
||||||
|
msg_amount_max=msg_amount_max,
|
||||||
|
token_in=recv_token,
|
||||||
|
token_out=send_token
|
||||||
|
) as (ctx, _),
|
||||||
|
):
|
||||||
|
recvd_hash = hashlib.sha256()
|
||||||
|
async for msg in chan:
|
||||||
|
if msg == b'bye':
|
||||||
|
await chan.send(b'bye')
|
||||||
|
break
|
||||||
|
|
||||||
|
recvd_hash.update(msg)
|
||||||
|
|
||||||
|
sent_hash = await ctx.result()
|
||||||
|
|
||||||
|
assert recvd_hash.hexdigest() == sent_hash
|
||||||
|
|
||||||
|
await an.cancel()
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
|
|
@ -121,9 +121,14 @@ def get_peer_by_name(
|
||||||
actor: Actor = current_actor()
|
actor: Actor = current_actor()
|
||||||
server: IPCServer = actor.ipc_server
|
server: IPCServer = actor.ipc_server
|
||||||
to_scan: dict[tuple, list[Channel]] = server._peers.copy()
|
to_scan: dict[tuple, list[Channel]] = server._peers.copy()
|
||||||
pchan: Channel|None = actor._parent_chan
|
|
||||||
if pchan:
|
# TODO: is this ever needed? creates a duplicate channel on actor._peers
|
||||||
to_scan[pchan.uid].append(pchan)
|
# when multiple find_actor calls are made to same actor from a single ctx
|
||||||
|
# which causes actor exit to hang waiting forever on
|
||||||
|
# `actor._no_more_peers.wait()` in `_runtime.async_main`
|
||||||
|
# pchan: Channel|None = actor._parent_chan
|
||||||
|
# if pchan:
|
||||||
|
# to_scan[pchan.uid].append(pchan)
|
||||||
|
|
||||||
for aid, chans in to_scan.items():
|
for aid, chans in to_scan.items():
|
||||||
_, peer_name = aid
|
_, peer_name = aid
|
||||||
|
|
|
@ -1,35 +1,99 @@
|
||||||
import os
|
import hashlib
|
||||||
import random
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def generate_sample_messages(
|
def generate_single_byte_msgs(amount: int) -> bytes:
|
||||||
amount: int,
|
'''
|
||||||
rand_min: int = 0,
|
Generate a byte instance of length `amount` with repeating ASCII digits 0..9.
|
||||||
rand_max: int = 0,
|
|
||||||
silent: bool = False
|
|
||||||
) -> tuple[list[bytes], int]:
|
|
||||||
|
|
||||||
msgs = []
|
'''
|
||||||
size = 0
|
# array [0, 1, 2, ..., amount-1], take mod 10 => [0..9], and map 0->'0'(48)
|
||||||
|
# up to 9->'9'(57).
|
||||||
|
arr = np.arange(amount, dtype=np.uint8) % 10
|
||||||
|
# move into ascii space
|
||||||
|
arr += 48
|
||||||
|
return arr.tobytes()
|
||||||
|
|
||||||
if not silent:
|
|
||||||
print(f'\ngenerating {amount} messages...')
|
|
||||||
|
|
||||||
for i in range(amount):
|
class RandomBytesGenerator:
|
||||||
msg = f'[{i:08}]'.encode('utf-8')
|
'''
|
||||||
|
Generate bytes msgs for tests.
|
||||||
|
|
||||||
if rand_max > 0:
|
messages will have the following format:
|
||||||
msg += os.urandom(
|
|
||||||
random.randint(rand_min, rand_max))
|
|
||||||
|
|
||||||
size += len(msg)
|
b'[{i:08}]' + random_bytes
|
||||||
|
|
||||||
msgs.append(msg)
|
so for message index 25:
|
||||||
|
|
||||||
if not silent and i and i % 10_000 == 0:
|
b'[00000025]' + random_bytes
|
||||||
print(f'{i} generated')
|
|
||||||
|
|
||||||
if not silent:
|
also generates sha256 hash of msgs.
|
||||||
print(f'done, {size:,} bytes in total')
|
|
||||||
|
|
||||||
return msgs, size
|
'''
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
amount: int,
|
||||||
|
rand_min: int = 0,
|
||||||
|
rand_max: int = 0
|
||||||
|
):
|
||||||
|
if rand_max < rand_min:
|
||||||
|
raise ValueError('rand_max must be >= rand_min')
|
||||||
|
|
||||||
|
self._amount = amount
|
||||||
|
self._rand_min = rand_min
|
||||||
|
self._rand_max = rand_max
|
||||||
|
self._index = 0
|
||||||
|
self._hasher = hashlib.sha256()
|
||||||
|
self._total_bytes = 0
|
||||||
|
|
||||||
|
self._lengths = np.random.randint(
|
||||||
|
rand_min,
|
||||||
|
rand_max + 1,
|
||||||
|
size=amount,
|
||||||
|
dtype=np.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self) -> bytes:
|
||||||
|
if self._index == self._amount:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
header = f'[{self._index:08}]'.encode('utf-8')
|
||||||
|
|
||||||
|
length = int(self._lengths[self._index])
|
||||||
|
msg = header + np.random.bytes(length)
|
||||||
|
|
||||||
|
self._hasher.update(msg)
|
||||||
|
self._total_bytes += length
|
||||||
|
self._index += 1
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hexdigest(self) -> str:
|
||||||
|
return self._hasher.hexdigest()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_bytes(self) -> int:
|
||||||
|
return self._total_bytes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_msgs(self) -> int:
|
||||||
|
return self._amount
|
||||||
|
|
||||||
|
@property
|
||||||
|
def msgs_generated(self) -> int:
|
||||||
|
return self._index
|
||||||
|
|
||||||
|
@property
|
||||||
|
def recommended_log_interval(self) -> int:
|
||||||
|
max_msg_size = 10 + self._rand_max
|
||||||
|
|
||||||
|
if max_msg_size <= 32 * 1024:
|
||||||
|
return 10_000
|
||||||
|
|
||||||
|
else:
|
||||||
|
return 1000
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
|
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
A modular IPC layer supporting the power of cross-process SC!
|
A modular IPC layer supporting the power of cross-process SC!
|
||||||
|
|
||||||
|
|
|
@ -1,253 +0,0 @@
|
||||||
# tractor: structured concurrent "actors".
|
|
||||||
# Copyright 2018-eternity Tyler Goodlet.
|
|
||||||
|
|
||||||
# This program is free software: you can redistribute it and/or modify
|
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
|
||||||
# the Free Software Foundation, either version 3 of the License, or
|
|
||||||
# (at your option) any later version.
|
|
||||||
|
|
||||||
# This program is distributed in the hope that it will be useful,
|
|
||||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
# GNU Affero General Public License for more details.
|
|
||||||
|
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
'''
|
|
||||||
IPC Reliable RingBuffer implementation
|
|
||||||
|
|
||||||
'''
|
|
||||||
from __future__ import annotations
|
|
||||||
from contextlib import contextmanager as cm
|
|
||||||
from multiprocessing.shared_memory import SharedMemory
|
|
||||||
|
|
||||||
import trio
|
|
||||||
from msgspec import (
|
|
||||||
Struct,
|
|
||||||
to_builtins
|
|
||||||
)
|
|
||||||
|
|
||||||
from ._linux import (
|
|
||||||
EFD_NONBLOCK,
|
|
||||||
open_eventfd,
|
|
||||||
EventFD
|
|
||||||
)
|
|
||||||
from ._mp_bs import disable_mantracker
|
|
||||||
|
|
||||||
|
|
||||||
disable_mantracker()
|
|
||||||
|
|
||||||
|
|
||||||
class RBToken(Struct, frozen=True):
|
|
||||||
'''
|
|
||||||
RingBuffer token contains necesary info to open the two
|
|
||||||
eventfds and the shared memory
|
|
||||||
|
|
||||||
'''
|
|
||||||
shm_name: str
|
|
||||||
write_eventfd: int
|
|
||||||
wrap_eventfd: int
|
|
||||||
buf_size: int
|
|
||||||
|
|
||||||
def as_msg(self):
|
|
||||||
return to_builtins(self)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_msg(cls, msg: dict) -> RBToken:
|
|
||||||
if isinstance(msg, RBToken):
|
|
||||||
return msg
|
|
||||||
|
|
||||||
return RBToken(**msg)
|
|
||||||
|
|
||||||
|
|
||||||
@cm
|
|
||||||
def open_ringbuf(
|
|
||||||
shm_name: str,
|
|
||||||
buf_size: int = 10 * 1024,
|
|
||||||
write_efd_flags: int = 0,
|
|
||||||
wrap_efd_flags: int = 0
|
|
||||||
) -> RBToken:
|
|
||||||
shm = SharedMemory(
|
|
||||||
name=shm_name,
|
|
||||||
size=buf_size,
|
|
||||||
create=True
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
token = RBToken(
|
|
||||||
shm_name=shm_name,
|
|
||||||
write_eventfd=open_eventfd(flags=write_efd_flags),
|
|
||||||
wrap_eventfd=open_eventfd(flags=wrap_efd_flags),
|
|
||||||
buf_size=buf_size
|
|
||||||
)
|
|
||||||
yield token
|
|
||||||
|
|
||||||
finally:
|
|
||||||
shm.unlink()
|
|
||||||
|
|
||||||
|
|
||||||
class RingBuffSender(trio.abc.SendStream):
|
|
||||||
'''
|
|
||||||
IPC Reliable Ring Buffer sender side implementation
|
|
||||||
|
|
||||||
`eventfd(2)` is used for wrap around sync, and also to signal
|
|
||||||
writes to the reader.
|
|
||||||
|
|
||||||
'''
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
token: RBToken,
|
|
||||||
start_ptr: int = 0,
|
|
||||||
):
|
|
||||||
token = RBToken.from_msg(token)
|
|
||||||
self._shm = SharedMemory(
|
|
||||||
name=token.shm_name,
|
|
||||||
size=token.buf_size,
|
|
||||||
create=False
|
|
||||||
)
|
|
||||||
self._write_event = EventFD(token.write_eventfd, 'w')
|
|
||||||
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
|
|
||||||
self._ptr = start_ptr
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key(self) -> str:
|
|
||||||
return self._shm.name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def size(self) -> int:
|
|
||||||
return self._shm.size
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ptr(self) -> int:
|
|
||||||
return self._ptr
|
|
||||||
|
|
||||||
@property
|
|
||||||
def write_fd(self) -> int:
|
|
||||||
return self._write_event.fd
|
|
||||||
|
|
||||||
@property
|
|
||||||
def wrap_fd(self) -> int:
|
|
||||||
return self._wrap_event.fd
|
|
||||||
|
|
||||||
async def send_all(self, data: bytes | bytearray | memoryview):
|
|
||||||
# while data is larger than the remaining buf
|
|
||||||
target_ptr = self.ptr + len(data)
|
|
||||||
while target_ptr > self.size:
|
|
||||||
# write all bytes that fit
|
|
||||||
remaining = self.size - self.ptr
|
|
||||||
self._shm.buf[self.ptr:] = data[:remaining]
|
|
||||||
# signal write and wait for reader wrap around
|
|
||||||
self._write_event.write(remaining)
|
|
||||||
await self._wrap_event.read()
|
|
||||||
|
|
||||||
# wrap around and trim already written bytes
|
|
||||||
self._ptr = 0
|
|
||||||
data = data[remaining:]
|
|
||||||
target_ptr = self._ptr + len(data)
|
|
||||||
|
|
||||||
# remaining data fits on buffer
|
|
||||||
self._shm.buf[self.ptr:target_ptr] = data
|
|
||||||
self._write_event.write(len(data))
|
|
||||||
self._ptr = target_ptr
|
|
||||||
|
|
||||||
async def wait_send_all_might_not_block(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def aclose(self):
|
|
||||||
self._write_event.close()
|
|
||||||
self._wrap_event.close()
|
|
||||||
self._shm.close()
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
self._write_event.open()
|
|
||||||
self._wrap_event.open()
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class RingBuffReceiver(trio.abc.ReceiveStream):
|
|
||||||
'''
|
|
||||||
IPC Reliable Ring Buffer receiver side implementation
|
|
||||||
|
|
||||||
`eventfd(2)` is used for wrap around sync, and also to signal
|
|
||||||
writes to the reader.
|
|
||||||
|
|
||||||
'''
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
token: RBToken,
|
|
||||||
start_ptr: int = 0,
|
|
||||||
flags: int = 0
|
|
||||||
):
|
|
||||||
token = RBToken.from_msg(token)
|
|
||||||
self._shm = SharedMemory(
|
|
||||||
name=token.shm_name,
|
|
||||||
size=token.buf_size,
|
|
||||||
create=False
|
|
||||||
)
|
|
||||||
self._write_event = EventFD(token.write_eventfd, 'w')
|
|
||||||
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
|
|
||||||
self._ptr = start_ptr
|
|
||||||
self._flags = flags
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key(self) -> str:
|
|
||||||
return self._shm.name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def size(self) -> int:
|
|
||||||
return self._shm.size
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ptr(self) -> int:
|
|
||||||
return self._ptr
|
|
||||||
|
|
||||||
@property
|
|
||||||
def write_fd(self) -> int:
|
|
||||||
return self._write_event.fd
|
|
||||||
|
|
||||||
@property
|
|
||||||
def wrap_fd(self) -> int:
|
|
||||||
return self._wrap_event.fd
|
|
||||||
|
|
||||||
async def receive_some(
|
|
||||||
self,
|
|
||||||
max_bytes: int | None = None,
|
|
||||||
nb_timeout: float = 0.1
|
|
||||||
) -> memoryview:
|
|
||||||
# if non blocking eventfd enabled, do polling
|
|
||||||
# until next write, this allows signal handling
|
|
||||||
if self._flags | EFD_NONBLOCK:
|
|
||||||
delta = None
|
|
||||||
while delta is None:
|
|
||||||
try:
|
|
||||||
delta = await self._write_event.read()
|
|
||||||
|
|
||||||
except OSError as e:
|
|
||||||
if e.errno == 'EAGAIN':
|
|
||||||
continue
|
|
||||||
|
|
||||||
raise e
|
|
||||||
|
|
||||||
else:
|
|
||||||
delta = await self._write_event.read()
|
|
||||||
|
|
||||||
# fetch next segment and advance ptr
|
|
||||||
next_ptr = self._ptr + delta
|
|
||||||
segment = self._shm.buf[self._ptr:next_ptr]
|
|
||||||
self._ptr = next_ptr
|
|
||||||
|
|
||||||
if self.ptr == self.size:
|
|
||||||
# reached the end, signal wrap around
|
|
||||||
self._ptr = 0
|
|
||||||
self._wrap_event.write(1)
|
|
||||||
|
|
||||||
return segment
|
|
||||||
|
|
||||||
async def aclose(self):
|
|
||||||
self._write_event.close()
|
|
||||||
self._wrap_event.close()
|
|
||||||
self._shm.close()
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
self._write_event.open()
|
|
||||||
self._wrap_event.open()
|
|
||||||
return self
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,834 @@
|
||||||
|
# tractor: structured concurrent "actors".
|
||||||
|
# Copyright 2018-eternity Tyler Goodlet.
|
||||||
|
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
'''
|
||||||
|
Ring buffer ipc publish-subscribe mechanism brokered by ringd
|
||||||
|
can dynamically add new outputs (publisher) or inputs (subscriber)
|
||||||
|
'''
|
||||||
|
from typing import (
|
||||||
|
TypeVar,
|
||||||
|
Generic,
|
||||||
|
Callable,
|
||||||
|
Awaitable,
|
||||||
|
AsyncContextManager
|
||||||
|
)
|
||||||
|
from functools import partial
|
||||||
|
from contextlib import asynccontextmanager as acm
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import tractor
|
||||||
|
|
||||||
|
from msgspec.msgpack import (
|
||||||
|
Encoder,
|
||||||
|
Decoder
|
||||||
|
)
|
||||||
|
|
||||||
|
from tractor.ipc._ringbuf import (
|
||||||
|
RBToken,
|
||||||
|
PayloadT,
|
||||||
|
RingBufferSendChannel,
|
||||||
|
RingBufferReceiveChannel,
|
||||||
|
attach_to_ringbuf_sender,
|
||||||
|
attach_to_ringbuf_receiver
|
||||||
|
)
|
||||||
|
|
||||||
|
from tractor.trionics import (
|
||||||
|
order_send_channel,
|
||||||
|
order_receive_channel
|
||||||
|
)
|
||||||
|
|
||||||
|
import tractor.linux._fdshare as fdshare
|
||||||
|
|
||||||
|
|
||||||
|
log = tractor.log.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ChannelType = TypeVar('ChannelType')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChannelInfo:
|
||||||
|
token: RBToken
|
||||||
|
channel: ChannelType
|
||||||
|
cancel_scope: trio.CancelScope
|
||||||
|
teardown: trio.Event
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelManager(Generic[ChannelType]):
|
||||||
|
'''
|
||||||
|
Helper for managing channel resources and their handler tasks with
|
||||||
|
cancellation, add or remove channels dynamically!
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# nursery used to spawn channel handler tasks
|
||||||
|
n: trio.Nursery,
|
||||||
|
|
||||||
|
# acm will be used for setup & teardown of channel resources
|
||||||
|
open_channel_acm: Callable[..., AsyncContextManager[ChannelType]],
|
||||||
|
|
||||||
|
# long running bg task to handle channel
|
||||||
|
channel_task: Callable[..., Awaitable[None]]
|
||||||
|
):
|
||||||
|
self._n = n
|
||||||
|
self._open_channel = open_channel_acm
|
||||||
|
self._channel_task = channel_task
|
||||||
|
|
||||||
|
# signal when a new channel conects and we previously had none
|
||||||
|
self._connect_event = trio.Event()
|
||||||
|
|
||||||
|
# store channel runtime variables
|
||||||
|
self._channels: list[ChannelInfo] = []
|
||||||
|
|
||||||
|
self._is_closed: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self) -> bool:
|
||||||
|
return self._is_closed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self) -> list[ChannelInfo]:
|
||||||
|
return self._channels
|
||||||
|
|
||||||
|
async def _channel_handler_task(
|
||||||
|
self,
|
||||||
|
token: RBToken,
|
||||||
|
task_status=trio.TASK_STATUS_IGNORED,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Open channel resources, add to internal data structures, signal channel
|
||||||
|
connect through trio.Event, and run `channel_task` with cancel scope,
|
||||||
|
and finally, maybe remove channel from internal data structures.
|
||||||
|
|
||||||
|
Spawned by `add_channel` function, lock is held from begining of fn
|
||||||
|
until `task_status.started()` call.
|
||||||
|
|
||||||
|
kwargs are proxied to `self._open_channel` acm.
|
||||||
|
'''
|
||||||
|
async with self._open_channel(
|
||||||
|
token,
|
||||||
|
**kwargs
|
||||||
|
) as chan:
|
||||||
|
cancel_scope = trio.CancelScope()
|
||||||
|
info = ChannelInfo(
|
||||||
|
token=token,
|
||||||
|
channel=chan,
|
||||||
|
cancel_scope=cancel_scope,
|
||||||
|
teardown=trio.Event()
|
||||||
|
)
|
||||||
|
self._channels.append(info)
|
||||||
|
|
||||||
|
if len(self) == 1:
|
||||||
|
self._connect_event.set()
|
||||||
|
|
||||||
|
task_status.started()
|
||||||
|
|
||||||
|
with cancel_scope:
|
||||||
|
await self._channel_task(info)
|
||||||
|
|
||||||
|
self._maybe_destroy_channel(token.shm_name)
|
||||||
|
|
||||||
|
def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
|
||||||
|
'''
|
||||||
|
Given a channel name maybe return its index and value from
|
||||||
|
internal _channels list.
|
||||||
|
|
||||||
|
Only use after acquiring lock.
|
||||||
|
'''
|
||||||
|
for entry in enumerate(self._channels):
|
||||||
|
i, info = entry
|
||||||
|
if info.token.shm_name == name:
|
||||||
|
return entry
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_destroy_channel(self, name: str):
|
||||||
|
'''
|
||||||
|
If channel exists cancel its scope and remove from internal
|
||||||
|
_channels list.
|
||||||
|
|
||||||
|
'''
|
||||||
|
maybe_entry = self._find_channel(name)
|
||||||
|
if maybe_entry:
|
||||||
|
i, info = maybe_entry
|
||||||
|
info.cancel_scope.cancel()
|
||||||
|
info.teardown.set()
|
||||||
|
del self._channels[i]
|
||||||
|
|
||||||
|
async def add_channel(
|
||||||
|
self,
|
||||||
|
token: RBToken,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Add a new channel to be handled
|
||||||
|
|
||||||
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
await self._n.start(partial(
|
||||||
|
self._channel_handler_task,
|
||||||
|
RBToken.from_msg(token),
|
||||||
|
**kwargs
|
||||||
|
))
|
||||||
|
|
||||||
|
async def remove_channel(self, name: str):
|
||||||
|
'''
|
||||||
|
Remove a channel and stop its handling
|
||||||
|
|
||||||
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
maybe_entry = self._find_channel(name)
|
||||||
|
if not maybe_entry:
|
||||||
|
# return
|
||||||
|
raise RuntimeError(
|
||||||
|
f'tried to remove channel {name} but if does not exist'
|
||||||
|
)
|
||||||
|
|
||||||
|
i, info = maybe_entry
|
||||||
|
self._maybe_destroy_channel(name)
|
||||||
|
|
||||||
|
await info.teardown.wait()
|
||||||
|
|
||||||
|
# if that was last channel reset connect event
|
||||||
|
if len(self) == 0:
|
||||||
|
self._connect_event = trio.Event()
|
||||||
|
|
||||||
|
async def wait_for_channel(self):
|
||||||
|
'''
|
||||||
|
Wait until at least one channel added
|
||||||
|
|
||||||
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
await self._connect_event.wait()
|
||||||
|
self._connect_event = trio.Event()
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._channels)
|
||||||
|
|
||||||
|
def __getitem__(self, name: str):
|
||||||
|
maybe_entry = self._find_channel(name)
|
||||||
|
if maybe_entry:
|
||||||
|
_, info = maybe_entry
|
||||||
|
return info
|
||||||
|
|
||||||
|
raise KeyError(f'Channel {name} not found!')
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
self._is_closed = False
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self.closed:
|
||||||
|
log.warning('tried to close ChannelManager but its already closed...')
|
||||||
|
return
|
||||||
|
|
||||||
|
for info in self._channels:
|
||||||
|
if info.channel.closed:
|
||||||
|
continue
|
||||||
|
|
||||||
|
await info.channel.aclose()
|
||||||
|
await self.remove_channel(info.token.shm_name)
|
||||||
|
|
||||||
|
self._is_closed = True
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
class RingBufferPublisher(trio.abc.SendChannel[PayloadT]):
|
||||||
|
'''
|
||||||
|
Use ChannelManager to create a multi ringbuf round robin sender that can
|
||||||
|
dynamically add or remove more outputs.
|
||||||
|
|
||||||
|
Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its
|
||||||
|
lifecycle.
|
||||||
|
|
||||||
|
'''
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n: trio.Nursery,
|
||||||
|
|
||||||
|
# amount of msgs to each ring before switching turns
|
||||||
|
msgs_per_turn: int = 1,
|
||||||
|
|
||||||
|
# global batch size for all channels
|
||||||
|
batch_size: int = 1,
|
||||||
|
|
||||||
|
encoder: Encoder | None = None
|
||||||
|
):
|
||||||
|
self._batch_size: int = batch_size
|
||||||
|
self.msgs_per_turn = msgs_per_turn
|
||||||
|
self._enc = encoder
|
||||||
|
|
||||||
|
# helper to manage acms + long running tasks
|
||||||
|
self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]](
|
||||||
|
n,
|
||||||
|
self._open_channel,
|
||||||
|
self._channel_task
|
||||||
|
)
|
||||||
|
|
||||||
|
# ensure no concurrent `.send()` calls
|
||||||
|
self._send_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
# index of channel to be used for next send
|
||||||
|
self._next_turn: int = 0
|
||||||
|
# amount of messages sent this turn
|
||||||
|
self._turn_msgs: int = 0
|
||||||
|
# have we closed this publisher?
|
||||||
|
# set to `False` on `.__aenter__()`
|
||||||
|
self._is_closed: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self) -> bool:
|
||||||
|
return self._is_closed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_size(self) -> int:
|
||||||
|
return self._batch_size
|
||||||
|
|
||||||
|
@batch_size.setter
|
||||||
|
def batch_size(self, value: int) -> None:
|
||||||
|
for info in self.channels:
|
||||||
|
info.channel.batch_size = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self) -> list[ChannelInfo]:
|
||||||
|
return self._chanmngr.channels
|
||||||
|
|
||||||
|
def _get_next_turn(self) -> int:
|
||||||
|
'''
|
||||||
|
Maybe switch turn and reset self._turn_msgs or just increment it.
|
||||||
|
Return current turn
|
||||||
|
'''
|
||||||
|
if self._turn_msgs == self.msgs_per_turn:
|
||||||
|
self._turn_msgs = 0
|
||||||
|
self._next_turn += 1
|
||||||
|
|
||||||
|
if self._next_turn >= len(self.channels):
|
||||||
|
self._next_turn = 0
|
||||||
|
|
||||||
|
else:
|
||||||
|
self._turn_msgs += 1
|
||||||
|
|
||||||
|
return self._next_turn
|
||||||
|
|
||||||
|
def get_channel(self, name: str) -> ChannelInfo:
|
||||||
|
'''
|
||||||
|
Get underlying ChannelInfo from name
|
||||||
|
|
||||||
|
'''
|
||||||
|
return self._chanmngr[name]
|
||||||
|
|
||||||
|
async def add_channel(
|
||||||
|
self,
|
||||||
|
token: RBToken,
|
||||||
|
):
|
||||||
|
await self._chanmngr.add_channel(token)
|
||||||
|
|
||||||
|
async def remove_channel(self, name: str):
|
||||||
|
await self._chanmngr.remove_channel(name)
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def _open_channel(
|
||||||
|
|
||||||
|
self,
|
||||||
|
token: RBToken
|
||||||
|
|
||||||
|
) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]:
|
||||||
|
async with attach_to_ringbuf_sender(
|
||||||
|
token,
|
||||||
|
batch_size=self._batch_size,
|
||||||
|
encoder=self._enc
|
||||||
|
) as ring:
|
||||||
|
yield ring
|
||||||
|
|
||||||
|
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||||
|
'''
|
||||||
|
Wait forever until channel cancellation
|
||||||
|
|
||||||
|
'''
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
async def send(self, msg: bytes):
|
||||||
|
'''
|
||||||
|
If no output channels connected, wait until one, then fetch the next
|
||||||
|
channel based on turn.
|
||||||
|
|
||||||
|
Needs to acquire `self._send_lock` to ensure no concurrent calls.
|
||||||
|
|
||||||
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
if self._send_lock.locked():
|
||||||
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
async with self._send_lock:
|
||||||
|
# wait at least one decoder connected
|
||||||
|
if len(self.channels) == 0:
|
||||||
|
await self._chanmngr.wait_for_channel()
|
||||||
|
|
||||||
|
turn = self._get_next_turn()
|
||||||
|
|
||||||
|
info = self.channels[turn]
|
||||||
|
await info.channel.send(msg)
|
||||||
|
|
||||||
|
async def broadcast(self, msg: PayloadT):
|
||||||
|
'''
|
||||||
|
Send a msg to all channels, if no channels connected, does nothing.
|
||||||
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
for info in self.channels:
|
||||||
|
await info.channel.send(msg)
|
||||||
|
|
||||||
|
async def flush(self, new_batch_size: int | None = None):
|
||||||
|
for info in self.channels:
|
||||||
|
try:
|
||||||
|
await info.channel.flush(new_batch_size=new_batch_size)
|
||||||
|
|
||||||
|
except trio.ClosedResourceError:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self._is_closed = False
|
||||||
|
self._chanmngr.open()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
if self.closed:
|
||||||
|
log.warning('tried to close RingBufferPublisher but its already closed...')
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._chanmngr.close()
|
||||||
|
|
||||||
|
self._is_closed = True
|
||||||
|
|
||||||
|
|
||||||
|
class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]):
|
||||||
|
'''
|
||||||
|
Use ChannelManager to create a multi ringbuf receiver that can
|
||||||
|
dynamically add or remove more inputs and combine all into a single output.
|
||||||
|
|
||||||
|
In order for `self.receive` messages to be returned in order, publisher
|
||||||
|
will send all payloads as `OrderedPayload` msgpack encoded msgs, this
|
||||||
|
allows our channel handler tasks to just stash the out of order payloads
|
||||||
|
inside `self._pending_payloads` and if a in order payload is available
|
||||||
|
signal through `self._new_payload_event`.
|
||||||
|
|
||||||
|
On `self.receive` we wait until at least one channel is connected, then if
|
||||||
|
an in order payload is pending, we pop and return it, in case no in order
|
||||||
|
payload is available wait until next `self._new_payload_event.set()`.
|
||||||
|
|
||||||
|
'''
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n: trio.Nursery,
|
||||||
|
|
||||||
|
decoder: Decoder | None = None
|
||||||
|
):
|
||||||
|
self._dec = decoder
|
||||||
|
self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]](
|
||||||
|
n,
|
||||||
|
self._open_channel,
|
||||||
|
self._channel_task
|
||||||
|
)
|
||||||
|
|
||||||
|
self._schan, self._rchan = trio.open_memory_channel(0)
|
||||||
|
|
||||||
|
self._is_closed: bool = True
|
||||||
|
|
||||||
|
self._receive_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self) -> bool:
|
||||||
|
return self._is_closed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self) -> list[ChannelInfo]:
|
||||||
|
return self._chanmngr.channels
|
||||||
|
|
||||||
|
def get_channel(self, name: str):
|
||||||
|
return self._chanmngr[name]
|
||||||
|
|
||||||
|
async def add_channel(
|
||||||
|
self,
|
||||||
|
token: RBToken
|
||||||
|
):
|
||||||
|
await self._chanmngr.add_channel(token)
|
||||||
|
|
||||||
|
async def remove_channel(self, name: str):
|
||||||
|
await self._chanmngr.remove_channel(name)
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def _open_channel(
|
||||||
|
|
||||||
|
self,
|
||||||
|
token: RBToken
|
||||||
|
|
||||||
|
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||||
|
async with attach_to_ringbuf_receiver(
|
||||||
|
token,
|
||||||
|
decoder=self._dec
|
||||||
|
) as ring:
|
||||||
|
yield ring
|
||||||
|
|
||||||
|
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||||
|
'''
|
||||||
|
Iterate over receive channel messages, decode them as `OrderedPayload`s
|
||||||
|
and stash them in `self._pending_payloads`, in case we can pop next in
|
||||||
|
order payload, signal through setting `self._new_payload_event`.
|
||||||
|
|
||||||
|
'''
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = await info.channel.receive()
|
||||||
|
await self._schan.send(msg)
|
||||||
|
|
||||||
|
except tractor.linux.eventfd.EFDReadCancelled as e:
|
||||||
|
# when channel gets removed while we are doing a receive
|
||||||
|
log.exception(e)
|
||||||
|
break
|
||||||
|
|
||||||
|
except trio.EndOfChannel:
|
||||||
|
break
|
||||||
|
|
||||||
|
except trio.ClosedResourceError:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def receive(self) -> PayloadT:
|
||||||
|
'''
|
||||||
|
Receive next in order msg
|
||||||
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
if self._receive_lock.locked():
|
||||||
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
async with self._receive_lock:
|
||||||
|
return await self._rchan.receive()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self._is_closed = False
|
||||||
|
self._chanmngr.open()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
if self.closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._chanmngr.close()
|
||||||
|
await self._schan.aclose()
|
||||||
|
await self._rchan.aclose()
|
||||||
|
|
||||||
|
self._is_closed = True
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
Actor module for managing publisher & subscriber channels remotely through
|
||||||
|
`tractor.context` rpc
|
||||||
|
'''
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PublisherEntry:
|
||||||
|
publisher: RingBufferPublisher | None = None
|
||||||
|
is_set: trio.Event = trio.Event()
|
||||||
|
|
||||||
|
|
||||||
|
_publishers: dict[str, PublisherEntry] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_init_publisher(topic: str) -> PublisherEntry:
|
||||||
|
entry = _publishers.get(topic, None)
|
||||||
|
if not entry:
|
||||||
|
entry = PublisherEntry()
|
||||||
|
_publishers[topic] = entry
|
||||||
|
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
def set_publisher(topic: str, pub: RingBufferPublisher):
|
||||||
|
global _publishers
|
||||||
|
|
||||||
|
entry = _publishers.get(topic, None)
|
||||||
|
if not entry:
|
||||||
|
entry = maybe_init_publisher(topic)
|
||||||
|
|
||||||
|
if entry.publisher:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'publisher for topic {topic} already set on {tractor.current_actor()}'
|
||||||
|
)
|
||||||
|
|
||||||
|
entry.publisher = pub
|
||||||
|
entry.is_set.set()
|
||||||
|
|
||||||
|
|
||||||
|
def get_publisher(topic: str = 'default') -> RingBufferPublisher:
|
||||||
|
entry = _publishers.get(topic, None)
|
||||||
|
if not entry or not entry.publisher:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'{tractor.current_actor()} tried to get publisher'
|
||||||
|
'but it\'s not set'
|
||||||
|
)
|
||||||
|
|
||||||
|
return entry.publisher
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_publisher(topic: str) -> RingBufferPublisher:
|
||||||
|
entry = maybe_init_publisher(topic)
|
||||||
|
await entry.is_set.wait()
|
||||||
|
return entry.publisher
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def _add_pub_channel(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
topic: str,
|
||||||
|
token: RBToken
|
||||||
|
):
|
||||||
|
publisher = await wait_publisher(topic)
|
||||||
|
await publisher.add_channel(token)
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def _remove_pub_channel(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
topic: str,
|
||||||
|
ring_name: str
|
||||||
|
):
|
||||||
|
publisher = await wait_publisher(topic)
|
||||||
|
maybe_token = fdshare.maybe_get_fds(ring_name)
|
||||||
|
if maybe_token:
|
||||||
|
await publisher.remove_channel(ring_name)
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open_pub_channel_at(
|
||||||
|
actor_name: str,
|
||||||
|
token: RBToken,
|
||||||
|
topic: str = 'default',
|
||||||
|
):
|
||||||
|
async with tractor.find_actor(actor_name) as portal:
|
||||||
|
await portal.run(_add_pub_channel, topic=topic, token=token)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
|
||||||
|
except trio.Cancelled:
|
||||||
|
log.warning(
|
||||||
|
'open_pub_channel_at got cancelled!\n'
|
||||||
|
f'\tactor_name = {actor_name}\n'
|
||||||
|
f'\ttoken = {token}\n'
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SubscriberEntry:
|
||||||
|
subscriber: RingBufferSubscriber | None = None
|
||||||
|
is_set: trio.Event = trio.Event()
|
||||||
|
|
||||||
|
|
||||||
|
_subscribers: dict[str, SubscriberEntry] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_init_subscriber(topic: str) -> SubscriberEntry:
|
||||||
|
entry = _subscribers.get(topic, None)
|
||||||
|
if not entry:
|
||||||
|
entry = SubscriberEntry()
|
||||||
|
_subscribers[topic] = entry
|
||||||
|
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
def set_subscriber(topic: str, sub: RingBufferSubscriber):
|
||||||
|
global _subscribers
|
||||||
|
|
||||||
|
entry = _subscribers.get(topic, None)
|
||||||
|
if not entry:
|
||||||
|
entry = maybe_init_subscriber(topic)
|
||||||
|
|
||||||
|
if entry.subscriber:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'subscriber for topic {topic} already set on {tractor.current_actor()}'
|
||||||
|
)
|
||||||
|
|
||||||
|
entry.subscriber = sub
|
||||||
|
entry.is_set.set()
|
||||||
|
|
||||||
|
|
||||||
|
def get_subscriber(topic: str = 'default') -> RingBufferSubscriber:
|
||||||
|
entry = _subscribers.get(topic, None)
|
||||||
|
if not entry or not entry.subscriber:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'{tractor.current_actor()} tried to get subscriber'
|
||||||
|
'but it\'s not set'
|
||||||
|
)
|
||||||
|
|
||||||
|
return entry.subscriber
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_subscriber(topic: str) -> RingBufferSubscriber:
|
||||||
|
entry = maybe_init_subscriber(topic)
|
||||||
|
await entry.is_set.wait()
|
||||||
|
return entry.subscriber
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def _add_sub_channel(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
topic: str,
|
||||||
|
token: RBToken
|
||||||
|
):
|
||||||
|
subscriber = await wait_subscriber(topic)
|
||||||
|
await subscriber.add_channel(token)
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def _remove_sub_channel(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
topic: str,
|
||||||
|
ring_name: str
|
||||||
|
):
|
||||||
|
subscriber = await wait_subscriber(topic)
|
||||||
|
maybe_token = fdshare.maybe_get_fds(ring_name)
|
||||||
|
if maybe_token:
|
||||||
|
await subscriber.remove_channel(ring_name)
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open_sub_channel_at(
|
||||||
|
actor_name: str,
|
||||||
|
token: RBToken,
|
||||||
|
topic: str = 'default',
|
||||||
|
):
|
||||||
|
async with tractor.find_actor(actor_name) as portal:
|
||||||
|
await portal.run(_add_sub_channel, topic=topic, token=token)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
|
||||||
|
except trio.Cancelled:
|
||||||
|
log.warning(
|
||||||
|
'open_sub_channel_at got cancelled!\n'
|
||||||
|
f'\tactor_name = {actor_name}\n'
|
||||||
|
f'\ttoken = {token}\n'
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name)
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
High level helpers to open publisher & subscriber
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open_ringbuf_publisher(
|
||||||
|
# name to distinguish this publisher
|
||||||
|
topic: str = 'default',
|
||||||
|
|
||||||
|
# global batch size for channels
|
||||||
|
batch_size: int = 1,
|
||||||
|
|
||||||
|
# messages before changing output channel
|
||||||
|
msgs_per_turn: int = 1,
|
||||||
|
|
||||||
|
encoder: Encoder | None = None,
|
||||||
|
|
||||||
|
# ensure subscriber receives in same order publisher sent
|
||||||
|
# causes it to use wrapped payloads which contain the og
|
||||||
|
# index
|
||||||
|
guarantee_order: bool = False,
|
||||||
|
|
||||||
|
# on creation, set the `_publisher` global in order to use the provided
|
||||||
|
# tractor.context & helper utils for adding and removing new channels from
|
||||||
|
# remote actors
|
||||||
|
set_module_var: bool = True
|
||||||
|
|
||||||
|
) -> AsyncContextManager[RingBufferPublisher]:
|
||||||
|
'''
|
||||||
|
Open a new ringbuf publisher
|
||||||
|
|
||||||
|
'''
|
||||||
|
async with (
|
||||||
|
trio.open_nursery(strict_exception_groups=False) as n,
|
||||||
|
RingBufferPublisher(
|
||||||
|
n,
|
||||||
|
batch_size=batch_size,
|
||||||
|
encoder=encoder,
|
||||||
|
) as publisher
|
||||||
|
):
|
||||||
|
if guarantee_order:
|
||||||
|
order_send_channel(publisher)
|
||||||
|
|
||||||
|
if set_module_var:
|
||||||
|
set_publisher(topic, publisher)
|
||||||
|
|
||||||
|
yield publisher
|
||||||
|
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open_ringbuf_subscriber(
|
||||||
|
# name to distinguish this subscriber
|
||||||
|
topic: str = 'default',
|
||||||
|
|
||||||
|
decoder: Decoder | None = None,
|
||||||
|
|
||||||
|
# expect indexed payloads and unwrap them in order
|
||||||
|
guarantee_order: bool = False,
|
||||||
|
|
||||||
|
# on creation, set the `_subscriber` global in order to use the provided
|
||||||
|
# tractor.context & helper utils for adding and removing new channels from
|
||||||
|
# remote actors
|
||||||
|
set_module_var: bool = True
|
||||||
|
) -> AsyncContextManager[RingBufferPublisher]:
|
||||||
|
'''
|
||||||
|
Open a new ringbuf subscriber
|
||||||
|
|
||||||
|
'''
|
||||||
|
async with (
|
||||||
|
trio.open_nursery(strict_exception_groups=False) as n,
|
||||||
|
RingBufferSubscriber(n, decoder=decoder) as subscriber
|
||||||
|
):
|
||||||
|
# maybe monkey patch `.receive` to use indexed payloads
|
||||||
|
if guarantee_order:
|
||||||
|
order_receive_channel(subscriber)
|
||||||
|
|
||||||
|
# maybe set global module var for remote actor channel updates
|
||||||
|
if set_module_var:
|
||||||
|
set_subscriber(topic, subscriber)
|
||||||
|
|
||||||
|
yield subscriber
|
||||||
|
|
||||||
|
n.cancel_scope.cancel()
|
|
@ -78,7 +78,7 @@ class MsgTransport(Protocol):
|
||||||
# eventual msg definition/types?
|
# eventual msg definition/types?
|
||||||
# - https://docs.python.org/3/library/typing.html#typing.Protocol
|
# - https://docs.python.org/3/library/typing.html#typing.Protocol
|
||||||
|
|
||||||
stream: trio.SocketStream
|
stream: trio.abc.Stream
|
||||||
drained: list[MsgType]
|
drained: list[MsgType]
|
||||||
|
|
||||||
address_type: ClassVar[Type[Address]]
|
address_type: ClassVar[Type[Address]]
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
# tractor: structured concurrent "actors".
|
||||||
|
# Copyright 2018-eternity Tyler Goodlet.
|
||||||
|
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
@ -0,0 +1,316 @@
|
||||||
|
# tractor: structured concurrent "actors".
|
||||||
|
# Copyright 2018-eternity Tyler Goodlet.
|
||||||
|
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
'''
|
||||||
|
Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and trio.
|
||||||
|
|
||||||
|
cpython impl:
|
||||||
|
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
import array
|
||||||
|
import tempfile
|
||||||
|
from uuid import uuid4
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import AsyncContextManager
|
||||||
|
from contextlib import asynccontextmanager as acm
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import tractor
|
||||||
|
from trio import socket
|
||||||
|
|
||||||
|
|
||||||
|
log = tractor.log.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FDSharingError(Exception):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
|
||||||
|
'''
|
||||||
|
Async trio reimplementation of `multiprocessing.reduction.sendfds`
|
||||||
|
|
||||||
|
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142
|
||||||
|
|
||||||
|
It's implemented using an async context manager in order to simplyfy usage
|
||||||
|
with `tractor.context`s, we can open a context in a remote actor that uses
|
||||||
|
this acm inside of it, and uses `ctx.started()` to signal the original
|
||||||
|
caller actor to perform the `recv_fds` call.
|
||||||
|
|
||||||
|
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
|
||||||
|
'''
|
||||||
|
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
|
await sock.bind(sock_path)
|
||||||
|
sock.listen(1)
|
||||||
|
|
||||||
|
yield # socket is setup, ready for receiver connect
|
||||||
|
|
||||||
|
# wait until receiver connects
|
||||||
|
conn, _ = await sock.accept()
|
||||||
|
|
||||||
|
# setup int array for fds
|
||||||
|
fds = array.array('i', fds)
|
||||||
|
|
||||||
|
# first byte of msg will be len of fds to send % 256, acting as a fd amount
|
||||||
|
# verification on `recv_fds` we refer to it as `check_byte`
|
||||||
|
msg = bytes([len(fds) % 256])
|
||||||
|
|
||||||
|
# send msg with custom SCM_RIGHTS type
|
||||||
|
await conn.sendmsg(
|
||||||
|
[msg],
|
||||||
|
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# finally wait receiver ack
|
||||||
|
if await conn.recv(1) != b'A':
|
||||||
|
raise FDSharingError('did not receive acknowledgement of fd')
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
sock.close()
|
||||||
|
os.unlink(sock_path)
|
||||||
|
|
||||||
|
|
||||||
|
async def recv_fds(sock_path: str, amount: int) -> tuple:
|
||||||
|
'''
|
||||||
|
Async trio reimplementation of `multiprocessing.reduction.recvfds`
|
||||||
|
|
||||||
|
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150
|
||||||
|
|
||||||
|
It's equivalent to std just using `trio.open_unix_socket` for connecting and
|
||||||
|
changes on error handling.
|
||||||
|
|
||||||
|
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
|
||||||
|
'''
|
||||||
|
stream = await trio.open_unix_socket(sock_path)
|
||||||
|
sock = stream.socket
|
||||||
|
|
||||||
|
# prepare int array for fds
|
||||||
|
a = array.array('i')
|
||||||
|
bytes_size = a.itemsize * amount
|
||||||
|
|
||||||
|
# receive 1 byte + space necesary for SCM_RIGHTS msg for {amount} fds
|
||||||
|
msg, ancdata, flags, addr = await sock.recvmsg(
|
||||||
|
1, socket.CMSG_SPACE(bytes_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# maybe failed to receive msg?
|
||||||
|
if not msg and not ancdata:
|
||||||
|
raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF')
|
||||||
|
|
||||||
|
# send ack, std comment mentions this ack pattern was to get around an
|
||||||
|
# old macosx bug, but they are not sure if its necesary any more, in
|
||||||
|
# any case its not a bad pattern to keep
|
||||||
|
await sock.send(b'A') # Ack
|
||||||
|
|
||||||
|
# expect to receive only one `ancdata` item
|
||||||
|
if len(ancdata) != 1:
|
||||||
|
raise FDSharingError(
|
||||||
|
f'Expected to receive exactly one \"ancdata\" but got {len(ancdata)}: {ancdata}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# unpack SCM_RIGHTS msg
|
||||||
|
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
|
||||||
|
|
||||||
|
# check proper msg type
|
||||||
|
if cmsg_level != socket.SOL_SOCKET:
|
||||||
|
raise FDSharingError(
|
||||||
|
f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}'
|
||||||
|
)
|
||||||
|
|
||||||
|
if cmsg_type != socket.SCM_RIGHTS:
|
||||||
|
raise FDSharingError(
|
||||||
|
f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# check proper data alignment
|
||||||
|
length = len(cmsg_data)
|
||||||
|
if length % a.itemsize != 0:
|
||||||
|
raise FDSharingError(
|
||||||
|
f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# attempt to cast as int array
|
||||||
|
a.frombytes(cmsg_data)
|
||||||
|
|
||||||
|
# validate length check byte
|
||||||
|
valid_check_byte = amount % 256 # check byte acording to `recv_fds` caller
|
||||||
|
recvd_check_byte = msg[0] # actual received check byte
|
||||||
|
payload_check_byte = len(a) % 256 # check byte acording to received fd int array
|
||||||
|
|
||||||
|
if recvd_check_byte != payload_check_byte:
|
||||||
|
raise FDSharingError(
|
||||||
|
'Validation failed: received check byte '
|
||||||
|
f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})'
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid_check_byte != recvd_check_byte:
|
||||||
|
raise FDSharingError(
|
||||||
|
'Validation failed: received check byte '
|
||||||
|
f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})'
|
||||||
|
)
|
||||||
|
|
||||||
|
return tuple(a)
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
Share FD actor module
|
||||||
|
|
||||||
|
Add "tractor.linux._fdshare" to enabled modules on actors to allow sharing of
|
||||||
|
FDs with other actors.
|
||||||
|
|
||||||
|
Use `share_fds` function to register a set of fds with a name, then other
|
||||||
|
actors can use `request_fds_from` function to retrieve the fds.
|
||||||
|
|
||||||
|
Use `unshare_fds` to disable sharing of a set of FDs.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
FDType = tuple[int]
|
||||||
|
|
||||||
|
_fds: dict[str, FDType] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_get_fds(name: str) -> FDType | None:
|
||||||
|
'''
|
||||||
|
Get registered FDs with a given name or return None
|
||||||
|
|
||||||
|
'''
|
||||||
|
return _fds.get(name, None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fds(name: str) -> FDType:
|
||||||
|
'''
|
||||||
|
Get registered FDs with a given name or raise
|
||||||
|
'''
|
||||||
|
fds = maybe_get_fds(name)
|
||||||
|
|
||||||
|
if not fds:
|
||||||
|
raise RuntimeError(f'No FDs with name {name} found!')
|
||||||
|
|
||||||
|
return fds
|
||||||
|
|
||||||
|
|
||||||
|
def share_fds(
|
||||||
|
name: str,
|
||||||
|
fds: tuple[int],
|
||||||
|
) -> None:
|
||||||
|
'''
|
||||||
|
Register a set of fds to be shared under a given name.
|
||||||
|
|
||||||
|
'''
|
||||||
|
this_actor = tractor.current_actor()
|
||||||
|
if __name__ not in this_actor.enable_modules:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Tried to share FDs {fds} with name {name}, but '
|
||||||
|
f'module {__name__} is not enabled in actor {this_actor.name}!'
|
||||||
|
)
|
||||||
|
|
||||||
|
maybe_fds = maybe_get_fds(name)
|
||||||
|
if maybe_fds:
|
||||||
|
raise RuntimeError(f'share FDs: {maybe_fds} already tied to name {name}')
|
||||||
|
|
||||||
|
_fds[name] = fds
|
||||||
|
|
||||||
|
|
||||||
|
def unshare_fds(name: str) -> None:
|
||||||
|
'''
|
||||||
|
Unregister a set of fds to disable sharing them.
|
||||||
|
|
||||||
|
'''
|
||||||
|
get_fds(name) # raise if not exists
|
||||||
|
|
||||||
|
del _fds[name]
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def _pass_fds(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
name: str,
|
||||||
|
sock_path: str
|
||||||
|
) -> None:
|
||||||
|
'''
|
||||||
|
Endpoint to request a set of FDs from current actor, will use `ctx.started`
|
||||||
|
to send original FDs, then `send_fds` will block until remote side finishes
|
||||||
|
the `recv_fds` call.
|
||||||
|
|
||||||
|
'''
|
||||||
|
# get fds or raise error
|
||||||
|
fds = get_fds(name)
|
||||||
|
|
||||||
|
# start fd passing context using socket on `sock_path`
|
||||||
|
async with send_fds(fds, sock_path):
|
||||||
|
# send original fds through ctx.started
|
||||||
|
await ctx.started(fds)
|
||||||
|
|
||||||
|
|
||||||
|
async def request_fds_from(
|
||||||
|
actor_name: str,
|
||||||
|
fds_name: str
|
||||||
|
) -> FDType:
|
||||||
|
'''
|
||||||
|
Use this function to retreive shared FDs from `actor_name`.
|
||||||
|
|
||||||
|
'''
|
||||||
|
this_actor = tractor.current_actor()
|
||||||
|
|
||||||
|
# create a temporary path for the UDS sock
|
||||||
|
sock_path = str(
|
||||||
|
Path(tempfile.gettempdir())
|
||||||
|
/
|
||||||
|
f'{fds_name}-from-{actor_name}-to-{this_actor.name}.sock'
|
||||||
|
)
|
||||||
|
|
||||||
|
# having a socket path length > 100 aprox can cause:
|
||||||
|
# OSError: AF_UNIX path too long
|
||||||
|
# https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/sys_un.h.html#tag_13_67_04
|
||||||
|
|
||||||
|
# attempt sock path creation with smaller names
|
||||||
|
if len(sock_path) > 100:
|
||||||
|
sock_path = str(
|
||||||
|
Path(tempfile.gettempdir())
|
||||||
|
/
|
||||||
|
f'{fds_name}-to-{this_actor.name}.sock'
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(sock_path) > 100:
|
||||||
|
# just use uuid4
|
||||||
|
sock_path = str(
|
||||||
|
Path(tempfile.gettempdir())
|
||||||
|
/
|
||||||
|
f'pass-fds-{uuid4()}.sock'
|
||||||
|
)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
tractor.find_actor(actor_name) as portal,
|
||||||
|
|
||||||
|
portal.open_context(
|
||||||
|
_pass_fds,
|
||||||
|
name=fds_name,
|
||||||
|
sock_path=sock_path
|
||||||
|
) as (ctx, fds_info),
|
||||||
|
):
|
||||||
|
# get original FDs
|
||||||
|
og_fds = fds_info
|
||||||
|
|
||||||
|
# retrieve copies of FDs
|
||||||
|
fds = await recv_fds(sock_path, len(og_fds))
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
f'{this_actor.name} received fds: {og_fds} -> {fds}'
|
||||||
|
)
|
||||||
|
|
||||||
|
return fds
|
|
@ -14,7 +14,7 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
'''
|
||||||
Linux specifics, for now we are only exposing EventFD
|
Expose libc eventfd APIs
|
||||||
|
|
||||||
'''
|
'''
|
||||||
import os
|
import os
|
||||||
|
@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
|
||||||
raise OSError(errno.errorcode[ffi.errno], 'close failed')
|
raise OSError(errno.errorcode[ffi.errno], 'close failed')
|
||||||
|
|
||||||
|
|
||||||
|
class EFDReadCancelled(Exception):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class EventFD:
|
class EventFD:
|
||||||
'''
|
'''
|
||||||
Use a previously opened eventfd(2), meant to be used in
|
Use a previously opened eventfd(2), meant to be used in
|
||||||
|
@ -124,26 +128,82 @@ class EventFD:
|
||||||
self._fd: int = fd
|
self._fd: int = fd
|
||||||
self._omode: str = omode
|
self._omode: str = omode
|
||||||
self._fobj = None
|
self._fobj = None
|
||||||
|
self._cscope: trio.CancelScope | None = None
|
||||||
|
self._is_closed: bool = True
|
||||||
|
self._read_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self) -> bool:
|
||||||
|
return self._is_closed
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fd(self) -> int | None:
|
def fd(self) -> int | None:
|
||||||
return self._fd
|
return self._fd
|
||||||
|
|
||||||
def write(self, value: int) -> int:
|
def write(self, value: int) -> int:
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
return write_eventfd(self._fd, value)
|
return write_eventfd(self._fd, value)
|
||||||
|
|
||||||
async def read(self) -> int:
|
async def read(self) -> int:
|
||||||
return await trio.to_thread.run_sync(
|
'''
|
||||||
read_eventfd, self._fd,
|
Async wrapper for `read_eventfd(self.fd)`
|
||||||
abandon_on_cancel=True
|
|
||||||
)
|
`trio.to_thread.run_sync` is used, need to use a `trio.CancelScope`
|
||||||
|
in order to make it cancellable when `self.close()` is called.
|
||||||
|
|
||||||
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
if self._read_lock.locked():
|
||||||
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
async with self._read_lock:
|
||||||
|
self._cscope = trio.CancelScope()
|
||||||
|
with self._cscope:
|
||||||
|
try:
|
||||||
|
return await trio.to_thread.run_sync(
|
||||||
|
read_eventfd, self._fd,
|
||||||
|
abandon_on_cancel=True
|
||||||
|
)
|
||||||
|
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != errno.EBADF:
|
||||||
|
raise
|
||||||
|
|
||||||
|
raise trio.BrokenResourceError
|
||||||
|
|
||||||
|
if self._cscope.cancelled_caught:
|
||||||
|
raise EFDReadCancelled
|
||||||
|
|
||||||
|
self._cscope = None
|
||||||
|
|
||||||
|
def read_nowait(self) -> int:
|
||||||
|
'''
|
||||||
|
Direct call to `read_eventfd(self.fd)`, unless `eventfd` was
|
||||||
|
opened with `EFD_NONBLOCK` its gonna block the thread.
|
||||||
|
|
||||||
|
'''
|
||||||
|
return read_eventfd(self._fd)
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
self._fobj = os.fdopen(self._fd, self._omode)
|
self._fobj = os.fdopen(self._fd, self._omode)
|
||||||
|
self._is_closed = False
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self._fobj:
|
if self._fobj:
|
||||||
self._fobj.close()
|
try:
|
||||||
|
self._fobj.close()
|
||||||
|
|
||||||
|
except OSError:
|
||||||
|
...
|
||||||
|
|
||||||
|
if self._cscope:
|
||||||
|
self._cscope.cancel()
|
||||||
|
|
||||||
|
self._is_closed = True
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.open()
|
self.open()
|
|
@ -32,3 +32,8 @@ from ._broadcast import (
|
||||||
from ._beg import (
|
from ._beg import (
|
||||||
collapse_eg as collapse_eg,
|
collapse_eg as collapse_eg,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ._ordering import (
|
||||||
|
order_send_channel as order_send_channel,
|
||||||
|
order_receive_channel as order_receive_channel
|
||||||
|
)
|
||||||
|
|
|
@ -70,7 +70,8 @@ async def maybe_open_nursery(
|
||||||
yield nursery
|
yield nursery
|
||||||
else:
|
else:
|
||||||
async with lib.open_nursery(**kwargs) as nursery:
|
async with lib.open_nursery(**kwargs) as nursery:
|
||||||
nursery.cancel_scope.shield = shield
|
if lib == trio:
|
||||||
|
nursery.cancel_scope.shield = shield
|
||||||
yield nursery
|
yield nursery
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,108 @@
|
||||||
|
# tractor: structured concurrent "actors".
|
||||||
|
# Copyright 2018-eternity Tyler Goodlet.
|
||||||
|
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
'''
|
||||||
|
Helpers to guarantee ordering of messages through a unordered channel
|
||||||
|
|
||||||
|
'''
|
||||||
|
from __future__ import annotations
|
||||||
|
from heapq import (
|
||||||
|
heappush,
|
||||||
|
heappop
|
||||||
|
)
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import msgspec
|
||||||
|
|
||||||
|
|
||||||
|
class OrderedPayload(msgspec.Struct, frozen=True):
|
||||||
|
index: int
|
||||||
|
payload: bytes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_msg(cls, msg: bytes) -> OrderedPayload:
|
||||||
|
return msgspec.msgpack.decode(msg, type=OrderedPayload)
|
||||||
|
|
||||||
|
def encode(self) -> bytes:
|
||||||
|
return msgspec.msgpack.encode(self)
|
||||||
|
|
||||||
|
|
||||||
|
def order_send_channel(
|
||||||
|
channel: trio.abc.SendChannel[bytes],
|
||||||
|
start_index: int = 0
|
||||||
|
):
|
||||||
|
|
||||||
|
next_index = start_index
|
||||||
|
send_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
channel._send = channel.send
|
||||||
|
channel._aclose = channel.aclose
|
||||||
|
|
||||||
|
async def send(msg: bytes):
|
||||||
|
nonlocal next_index
|
||||||
|
async with send_lock:
|
||||||
|
await channel._send(
|
||||||
|
OrderedPayload(
|
||||||
|
index=next_index,
|
||||||
|
payload=msg
|
||||||
|
).encode()
|
||||||
|
)
|
||||||
|
next_index += 1
|
||||||
|
|
||||||
|
async def aclose():
|
||||||
|
async with send_lock:
|
||||||
|
await channel._aclose()
|
||||||
|
|
||||||
|
channel.send = send
|
||||||
|
channel.aclose = aclose
|
||||||
|
|
||||||
|
|
||||||
|
def order_receive_channel(
|
||||||
|
channel: trio.abc.ReceiveChannel[bytes],
|
||||||
|
start_index: int = 0
|
||||||
|
):
|
||||||
|
next_index = start_index
|
||||||
|
pqueue = []
|
||||||
|
|
||||||
|
channel._receive = channel.receive
|
||||||
|
|
||||||
|
def can_pop_next() -> bool:
|
||||||
|
return (
|
||||||
|
len(pqueue) > 0
|
||||||
|
and
|
||||||
|
pqueue[0][0] == next_index
|
||||||
|
)
|
||||||
|
|
||||||
|
async def drain_to_heap():
|
||||||
|
while not can_pop_next():
|
||||||
|
msg = await channel._receive()
|
||||||
|
msg = OrderedPayload.from_msg(msg)
|
||||||
|
heappush(pqueue, (msg.index, msg.payload))
|
||||||
|
|
||||||
|
def pop_next():
|
||||||
|
nonlocal next_index
|
||||||
|
_, msg = heappop(pqueue)
|
||||||
|
next_index += 1
|
||||||
|
return msg
|
||||||
|
|
||||||
|
async def receive() -> bytes:
|
||||||
|
if can_pop_next():
|
||||||
|
return pop_next()
|
||||||
|
|
||||||
|
await drain_to_heap()
|
||||||
|
|
||||||
|
return pop_next()
|
||||||
|
|
||||||
|
channel.receive = receive
|
Loading…
Reference in New Issue