Compare commits
11 Commits
1762b3eb64
...
efd11f7d74
Author | SHA1 | Date |
---|---|---|
|
efd11f7d74 | |
|
76cee99fc2 | |
|
5f50206d84 | |
|
a47a7a39b1 | |
|
bab265b2d8 | |
|
010874bed5 | |
|
ea010ab46a | |
|
be7fc89ae9 | |
|
2a9a78651b | |
|
be818a720a | |
|
ba353bf46f |
|
@ -14,6 +14,6 @@ pkgs.mkShell {
|
||||||
|
|
||||||
shellHook = ''
|
shellHook = ''
|
||||||
set -e
|
set -e
|
||||||
uv venv .venv --python=3.12
|
uv venv .venv --python=3.11
|
||||||
'';
|
'';
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
import trio
|
||||||
|
import pytest
|
||||||
|
from tractor.ipc import (
|
||||||
|
open_eventfd,
|
||||||
|
EFDReadCancelled,
|
||||||
|
EventFD
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_eventfd_read_cancellation():
|
||||||
|
'''
|
||||||
|
Ensure EventFD.read raises EFDReadCancelled if EventFD.close()
|
||||||
|
is called.
|
||||||
|
|
||||||
|
'''
|
||||||
|
fd = open_eventfd()
|
||||||
|
|
||||||
|
async def _read(event: EventFD):
|
||||||
|
with pytest.raises(EFDReadCancelled):
|
||||||
|
await event.read()
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
with (
|
||||||
|
EventFD(fd, 'w') as event,
|
||||||
|
trio.fail_after(3)
|
||||||
|
):
|
||||||
|
n.start_soon(_read, event)
|
||||||
|
await trio.sleep(0.2)
|
||||||
|
event.close()
|
||||||
|
|
||||||
|
trio.run(main)
|
|
@ -1,15 +1,21 @@
|
||||||
import time
|
import time
|
||||||
|
import hashlib
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import pytest
|
import pytest
|
||||||
import tractor
|
import tractor
|
||||||
from tractor.ipc import (
|
from tractor.ipc import (
|
||||||
open_ringbuf,
|
open_ringbuf,
|
||||||
|
attach_to_ringbuf_receiver,
|
||||||
|
attach_to_ringbuf_sender,
|
||||||
|
attach_to_ringbuf_stream,
|
||||||
|
attach_to_ringbuf_channel,
|
||||||
RBToken,
|
RBToken,
|
||||||
RingBuffSender,
|
|
||||||
RingBuffReceiver
|
|
||||||
)
|
)
|
||||||
from tractor._testing.samples import generate_sample_messages
|
from tractor._testing.samples import (
|
||||||
|
generate_single_byte_msgs,
|
||||||
|
generate_sample_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
|
@ -17,19 +23,27 @@ async def child_read_shm(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
msg_amount: int,
|
msg_amount: int,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
total_bytes: int,
|
) -> str:
|
||||||
) -> None:
|
'''
|
||||||
recvd_bytes = 0
|
Sub-actor used in `test_ringbuf`.
|
||||||
await ctx.started()
|
|
||||||
start_ts = time.time()
|
|
||||||
async with RingBuffReceiver(token) as receiver:
|
|
||||||
while recvd_bytes < total_bytes:
|
|
||||||
msg = await receiver.receive_some()
|
|
||||||
recvd_bytes += len(msg)
|
|
||||||
|
|
||||||
# make sure we dont hold any memoryviews
|
Attach to a ringbuf and receive all messages until end of stream.
|
||||||
# before the ctx manager aclose()
|
Keep track of how many bytes received and also calculate
|
||||||
msg = None
|
sha256 of the whole byte stream.
|
||||||
|
|
||||||
|
Calculate and print performance stats, finally return calculated
|
||||||
|
hash.
|
||||||
|
|
||||||
|
'''
|
||||||
|
await ctx.started()
|
||||||
|
print('reader started')
|
||||||
|
recvd_bytes = 0
|
||||||
|
recvd_hash = hashlib.sha256()
|
||||||
|
start_ts = time.time()
|
||||||
|
async with attach_to_ringbuf_receiver(token) as receiver:
|
||||||
|
async for msg in receiver:
|
||||||
|
recvd_hash.update(msg)
|
||||||
|
recvd_bytes += len(msg)
|
||||||
|
|
||||||
end_ts = time.time()
|
end_ts = time.time()
|
||||||
elapsed = end_ts - start_ts
|
elapsed = end_ts - start_ts
|
||||||
|
@ -38,6 +52,9 @@ async def child_read_shm(
|
||||||
print(f'\n\telapsed ms: {elapsed_ms}')
|
print(f'\n\telapsed ms: {elapsed_ms}')
|
||||||
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
||||||
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
||||||
|
print(f'\treceived bytes: {recvd_bytes:,}')
|
||||||
|
|
||||||
|
return recvd_hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
|
@ -48,16 +65,32 @@ async def child_write_shm(
|
||||||
rand_max: int,
|
rand_max: int,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
) -> None:
|
) -> None:
|
||||||
msgs, total_bytes = generate_sample_messages(
|
'''
|
||||||
|
Sub-actor used in `test_ringbuf`
|
||||||
|
|
||||||
|
Generate `msg_amount` payloads with
|
||||||
|
`random.randint(rand_min, rand_max)` random bytes at the end,
|
||||||
|
Calculate sha256 hash and send it to parent on `ctx.started`.
|
||||||
|
|
||||||
|
Attach to ringbuf and send all generated messages.
|
||||||
|
|
||||||
|
'''
|
||||||
|
msgs, _total_bytes = generate_sample_messages(
|
||||||
msg_amount,
|
msg_amount,
|
||||||
rand_min=rand_min,
|
rand_min=rand_min,
|
||||||
rand_max=rand_max,
|
rand_max=rand_max,
|
||||||
)
|
)
|
||||||
await ctx.started(total_bytes)
|
print('writer hashing payload...')
|
||||||
async with RingBuffSender(token) as sender:
|
sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest()
|
||||||
|
print('writer done hashing.')
|
||||||
|
await ctx.started(sent_hash)
|
||||||
|
print('writer started')
|
||||||
|
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
await sender.send_all(msg)
|
await sender.send_all(msg)
|
||||||
|
|
||||||
|
print('writer exit')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'msg_amount,rand_min,rand_max,buf_size',
|
'msg_amount,rand_min,rand_max,buf_size',
|
||||||
|
@ -83,19 +116,23 @@ def test_ringbuf(
|
||||||
rand_max: int,
|
rand_max: int,
|
||||||
buf_size: int
|
buf_size: int
|
||||||
):
|
):
|
||||||
|
'''
|
||||||
|
- Open a new ring buf on root actor
|
||||||
|
- Open `child_write_shm` ctx in sub-actor which will generate a
|
||||||
|
random payload and send its hash on `ctx.started`, finally sending
|
||||||
|
the payload through the stream.
|
||||||
|
- Open `child_read_shm` ctx in sub-actor which will receive the
|
||||||
|
payload, calculate perf stats and return the hash.
|
||||||
|
- Compare both hashes
|
||||||
|
|
||||||
|
'''
|
||||||
async def main():
|
async def main():
|
||||||
with open_ringbuf(
|
with open_ringbuf(
|
||||||
'test_ringbuf',
|
'test_ringbuf',
|
||||||
buf_size=buf_size
|
buf_size=buf_size
|
||||||
) as token:
|
) as token:
|
||||||
proc_kwargs = {
|
proc_kwargs = {'pass_fds': token.fds}
|
||||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
|
||||||
}
|
|
||||||
|
|
||||||
common_kwargs = {
|
|
||||||
'msg_amount': msg_amount,
|
|
||||||
'token': token,
|
|
||||||
}
|
|
||||||
async with tractor.open_nursery() as an:
|
async with tractor.open_nursery() as an:
|
||||||
send_p = await an.start_actor(
|
send_p = await an.start_actor(
|
||||||
'ring_sender',
|
'ring_sender',
|
||||||
|
@ -110,17 +147,20 @@ def test_ringbuf(
|
||||||
async with (
|
async with (
|
||||||
send_p.open_context(
|
send_p.open_context(
|
||||||
child_write_shm,
|
child_write_shm,
|
||||||
|
token=token,
|
||||||
|
msg_amount=msg_amount,
|
||||||
rand_min=rand_min,
|
rand_min=rand_min,
|
||||||
rand_max=rand_max,
|
rand_max=rand_max,
|
||||||
**common_kwargs
|
) as (_sctx, sent_hash),
|
||||||
) as (sctx, total_bytes),
|
|
||||||
recv_p.open_context(
|
recv_p.open_context(
|
||||||
child_read_shm,
|
child_read_shm,
|
||||||
**common_kwargs,
|
token=token,
|
||||||
total_bytes=total_bytes,
|
msg_amount=msg_amount
|
||||||
) as (sctx, _sent),
|
) as (rctx, _sent),
|
||||||
):
|
):
|
||||||
await recv_p.result()
|
recvd_hash = await rctx.result()
|
||||||
|
|
||||||
|
assert sent_hash == recvd_hash
|
||||||
|
|
||||||
await send_p.cancel_actor()
|
await send_p.cancel_actor()
|
||||||
await recv_p.cancel_actor()
|
await recv_p.cancel_actor()
|
||||||
|
@ -134,23 +174,28 @@ async def child_blocked_receiver(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
token: RBToken
|
token: RBToken
|
||||||
):
|
):
|
||||||
async with RingBuffReceiver(token) as receiver:
|
async with attach_to_ringbuf_receiver(token) as receiver:
|
||||||
await ctx.started()
|
await ctx.started()
|
||||||
await receiver.receive_some()
|
await receiver.receive_some()
|
||||||
|
|
||||||
|
|
||||||
def test_ring_reader_cancel():
|
def test_reader_cancel():
|
||||||
|
'''
|
||||||
|
Test that a receiver blocked on eventfd(2) read responds to
|
||||||
|
cancellation.
|
||||||
|
|
||||||
|
'''
|
||||||
async def main():
|
async def main():
|
||||||
with open_ringbuf('test_ring_cancel_reader') as token:
|
with open_ringbuf('test_ring_cancel_reader') as token:
|
||||||
async with (
|
async with (
|
||||||
tractor.open_nursery() as an,
|
tractor.open_nursery() as an,
|
||||||
RingBuffSender(token) as _sender,
|
attach_to_ringbuf_sender(token) as _sender,
|
||||||
):
|
):
|
||||||
recv_p = await an.start_actor(
|
recv_p = await an.start_actor(
|
||||||
'ring_blocked_receiver',
|
'ring_blocked_receiver',
|
||||||
enable_modules=[__name__],
|
enable_modules=[__name__],
|
||||||
proc_kwargs={
|
proc_kwargs={
|
||||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
'pass_fds': token.fds
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
async with (
|
async with (
|
||||||
|
@ -172,12 +217,17 @@ async def child_blocked_sender(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
token: RBToken
|
token: RBToken
|
||||||
):
|
):
|
||||||
async with RingBuffSender(token) as sender:
|
async with attach_to_ringbuf_sender(token) as sender:
|
||||||
await ctx.started()
|
await ctx.started()
|
||||||
await sender.send_all(b'this will wrap')
|
await sender.send_all(b'this will wrap')
|
||||||
|
|
||||||
|
|
||||||
def test_ring_sender_cancel():
|
def test_sender_cancel():
|
||||||
|
'''
|
||||||
|
Test that a sender blocked on eventfd(2) read responds to
|
||||||
|
cancellation.
|
||||||
|
|
||||||
|
'''
|
||||||
async def main():
|
async def main():
|
||||||
with open_ringbuf(
|
with open_ringbuf(
|
||||||
'test_ring_cancel_sender',
|
'test_ring_cancel_sender',
|
||||||
|
@ -188,7 +238,7 @@ def test_ring_sender_cancel():
|
||||||
'ring_blocked_sender',
|
'ring_blocked_sender',
|
||||||
enable_modules=[__name__],
|
enable_modules=[__name__],
|
||||||
proc_kwargs={
|
proc_kwargs={
|
||||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
'pass_fds': token.fds
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
async with (
|
async with (
|
||||||
|
@ -203,3 +253,171 @@ def test_ring_sender_cancel():
|
||||||
|
|
||||||
with pytest.raises(tractor._exceptions.ContextCancelled):
|
with pytest.raises(tractor._exceptions.ContextCancelled):
|
||||||
trio.run(main)
|
trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
def test_receiver_max_bytes():
|
||||||
|
'''
|
||||||
|
Test that RingBuffReceiver.receive_some's max_bytes optional
|
||||||
|
argument works correctly, send a msg of size 100, then
|
||||||
|
force receive of messages with max_bytes == 1, wait until
|
||||||
|
100 of these messages are received, then compare join of
|
||||||
|
msgs with original message
|
||||||
|
|
||||||
|
'''
|
||||||
|
msg = generate_single_byte_msgs(100)
|
||||||
|
msgs = []
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
with open_ringbuf(
|
||||||
|
'test_ringbuf_max_bytes',
|
||||||
|
buf_size=10
|
||||||
|
) as token:
|
||||||
|
async with (
|
||||||
|
trio.open_nursery() as n,
|
||||||
|
attach_to_ringbuf_sender(token, cleanup=False) as sender,
|
||||||
|
attach_to_ringbuf_receiver(token, cleanup=False) as receiver
|
||||||
|
):
|
||||||
|
async def _send_and_close():
|
||||||
|
await sender.send_all(msg)
|
||||||
|
await sender.aclose()
|
||||||
|
|
||||||
|
n.start_soon(_send_and_close)
|
||||||
|
while len(msgs) < len(msg):
|
||||||
|
msg_part = await receiver.receive_some(max_bytes=1)
|
||||||
|
assert len(msg_part) == 1
|
||||||
|
msgs.append(msg_part)
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
assert msg == b''.join(msgs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stapled_ringbuf():
|
||||||
|
'''
|
||||||
|
Open two ringbufs and give tokens to tasks (swap them such that in/out tokens
|
||||||
|
are inversed on each task) which will open the streams and use trio.StapledStream
|
||||||
|
to have a single bidirectional stream.
|
||||||
|
|
||||||
|
Then take turns to send and receive messages.
|
||||||
|
|
||||||
|
'''
|
||||||
|
msg = generate_single_byte_msgs(100)
|
||||||
|
pair_0_msgs = []
|
||||||
|
pair_1_msgs = []
|
||||||
|
|
||||||
|
pair_0_done = trio.Event()
|
||||||
|
pair_1_done = trio.Event()
|
||||||
|
|
||||||
|
async def pair_0(token_in: RBToken, token_out: RBToken):
|
||||||
|
async with attach_to_ringbuf_stream(
|
||||||
|
token_in,
|
||||||
|
token_out,
|
||||||
|
cleanup_in=False,
|
||||||
|
cleanup_out=False
|
||||||
|
) as stream:
|
||||||
|
# first turn to send
|
||||||
|
await stream.send_all(msg)
|
||||||
|
|
||||||
|
# second turn to receive
|
||||||
|
while len(pair_0_msgs) != len(msg):
|
||||||
|
_msg = await stream.receive_some(max_bytes=1)
|
||||||
|
pair_0_msgs.append(_msg)
|
||||||
|
|
||||||
|
pair_0_done.set()
|
||||||
|
await pair_1_done.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def pair_1(token_in: RBToken, token_out: RBToken):
|
||||||
|
async with attach_to_ringbuf_stream(
|
||||||
|
token_in,
|
||||||
|
token_out,
|
||||||
|
cleanup_in=False,
|
||||||
|
cleanup_out=False
|
||||||
|
) as stream:
|
||||||
|
# first turn to receive
|
||||||
|
while len(pair_1_msgs) != len(msg):
|
||||||
|
_msg = await stream.receive_some(max_bytes=1)
|
||||||
|
pair_1_msgs.append(_msg)
|
||||||
|
|
||||||
|
# second turn to send
|
||||||
|
await stream.send_all(msg)
|
||||||
|
|
||||||
|
pair_1_done.set()
|
||||||
|
await pair_0_done.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
with tractor.ipc.open_ringbuf_pair(
|
||||||
|
'test_stapled_ringbuf'
|
||||||
|
) as (token_0, token_1):
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
n.start_soon(pair_0, token_0, token_1)
|
||||||
|
n.start_soon(pair_1, token_1, token_0)
|
||||||
|
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
|
||||||
|
assert msg == b''.join(pair_0_msgs)
|
||||||
|
assert msg == b''.join(pair_1_msgs)
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def child_channel_sender(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
msg_amount_min: int,
|
||||||
|
msg_amount_max: int,
|
||||||
|
token_in: RBToken,
|
||||||
|
token_out: RBToken
|
||||||
|
):
|
||||||
|
import random
|
||||||
|
msgs, _total_bytes = generate_sample_messages(
|
||||||
|
random.randint(msg_amount_min, msg_amount_max),
|
||||||
|
rand_min=256,
|
||||||
|
rand_max=1024,
|
||||||
|
)
|
||||||
|
async with attach_to_ringbuf_channel(
|
||||||
|
token_in,
|
||||||
|
token_out
|
||||||
|
) as chan:
|
||||||
|
await ctx.started(msgs)
|
||||||
|
|
||||||
|
for msg in msgs:
|
||||||
|
await chan.send(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel():
|
||||||
|
|
||||||
|
msg_amount_min = 100
|
||||||
|
msg_amount_max = 1000
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
with tractor.ipc.open_ringbuf_pair(
|
||||||
|
'test_ringbuf_transport'
|
||||||
|
) as (token_0, token_1):
|
||||||
|
async with (
|
||||||
|
attach_to_ringbuf_channel(token_0, token_1) as chan,
|
||||||
|
tractor.open_nursery() as an
|
||||||
|
):
|
||||||
|
recv_p = await an.start_actor(
|
||||||
|
'test_ringbuf_transport_sender',
|
||||||
|
enable_modules=[__name__],
|
||||||
|
proc_kwargs={
|
||||||
|
'pass_fds': token_0.fds + token_1.fds
|
||||||
|
}
|
||||||
|
)
|
||||||
|
async with (
|
||||||
|
recv_p.open_context(
|
||||||
|
child_channel_sender,
|
||||||
|
msg_amount_min=msg_amount_min,
|
||||||
|
msg_amount_max=msg_amount_max,
|
||||||
|
token_in=token_1,
|
||||||
|
token_out=token_0
|
||||||
|
) as (ctx, msgs),
|
||||||
|
):
|
||||||
|
recv_msgs = []
|
||||||
|
async for msg in chan:
|
||||||
|
recv_msgs.append(msg)
|
||||||
|
|
||||||
|
await recv_p.cancel_actor()
|
||||||
|
assert recv_msgs == msgs
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
|
|
@ -2,19 +2,59 @@ import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def generate_single_byte_msgs(amount: int) -> bytes:
|
||||||
|
'''
|
||||||
|
Generate a byte instance of len `amount` with:
|
||||||
|
|
||||||
|
```
|
||||||
|
byte_at_index(i) = (i % 10).encode()
|
||||||
|
```
|
||||||
|
|
||||||
|
this results in constantly repeating sequences of:
|
||||||
|
|
||||||
|
b'0123456789'
|
||||||
|
|
||||||
|
'''
|
||||||
|
return b''.join(str(i % 10).encode() for i in range(amount))
|
||||||
|
|
||||||
|
|
||||||
def generate_sample_messages(
|
def generate_sample_messages(
|
||||||
amount: int,
|
amount: int,
|
||||||
rand_min: int = 0,
|
rand_min: int = 0,
|
||||||
rand_max: int = 0,
|
rand_max: int = 0,
|
||||||
silent: bool = False
|
silent: bool = False,
|
||||||
) -> tuple[list[bytes], int]:
|
) -> tuple[list[bytes], int]:
|
||||||
|
'''
|
||||||
|
Generate bytes msgs for tests.
|
||||||
|
|
||||||
|
Messages will have the following format:
|
||||||
|
|
||||||
|
```
|
||||||
|
b'[{i:08}]' + os.urandom(random.randint(rand_min, rand_max))
|
||||||
|
```
|
||||||
|
|
||||||
|
so for message index 25:
|
||||||
|
|
||||||
|
b'[00000025]' + random_bytes
|
||||||
|
|
||||||
|
'''
|
||||||
msgs = []
|
msgs = []
|
||||||
size = 0
|
size = 0
|
||||||
|
|
||||||
|
log_interval = None
|
||||||
if not silent:
|
if not silent:
|
||||||
print(f'\ngenerating {amount} messages...')
|
print(f'\ngenerating {amount} messages...')
|
||||||
|
|
||||||
|
# calculate an apropiate log interval based on
|
||||||
|
# max message size
|
||||||
|
max_msg_size = 10 + rand_max
|
||||||
|
|
||||||
|
if max_msg_size <= 32 * 1024:
|
||||||
|
log_interval = 10_000
|
||||||
|
|
||||||
|
else:
|
||||||
|
log_interval = 1000
|
||||||
|
|
||||||
for i in range(amount):
|
for i in range(amount):
|
||||||
msg = f'[{i:08}]'.encode('utf-8')
|
msg = f'[{i:08}]'.encode('utf-8')
|
||||||
|
|
||||||
|
@ -26,7 +66,13 @@ def generate_sample_messages(
|
||||||
|
|
||||||
msgs.append(msg)
|
msgs.append(msg)
|
||||||
|
|
||||||
if not silent and i and i % 10_000 == 0:
|
if (
|
||||||
|
not silent
|
||||||
|
and
|
||||||
|
i > 0
|
||||||
|
and
|
||||||
|
i % log_interval == 0
|
||||||
|
):
|
||||||
print(f'{i} generated')
|
print(f'{i} generated')
|
||||||
|
|
||||||
if not silent:
|
if not silent:
|
||||||
|
|
|
@ -44,12 +44,23 @@ if platform.system() == 'Linux':
|
||||||
write_eventfd as write_eventfd,
|
write_eventfd as write_eventfd,
|
||||||
read_eventfd as read_eventfd,
|
read_eventfd as read_eventfd,
|
||||||
close_eventfd as close_eventfd,
|
close_eventfd as close_eventfd,
|
||||||
|
EFDReadCancelled as EFDReadCancelled,
|
||||||
EventFD as EventFD,
|
EventFD as EventFD,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ._ringbuf import (
|
from ._ringbuf import (
|
||||||
RBToken as RBToken,
|
RBToken as RBToken,
|
||||||
|
open_ringbuf as open_ringbuf,
|
||||||
RingBuffSender as RingBuffSender,
|
RingBuffSender as RingBuffSender,
|
||||||
RingBuffReceiver as RingBuffReceiver,
|
RingBuffReceiver as RingBuffReceiver,
|
||||||
open_ringbuf as open_ringbuf
|
open_ringbuf_pair as open_ringbuf_pair,
|
||||||
|
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
|
||||||
|
attach_to_ringbuf_sender as attach_to_ringbuf_sender,
|
||||||
|
attach_to_ringbuf_stream as attach_to_ringbuf_stream,
|
||||||
|
RingBuffBytesSender as RingBuffBytesSender,
|
||||||
|
RingBuffBytesReceiver as RingBuffBytesReceiver,
|
||||||
|
RingBuffChannel as RingBuffChannel,
|
||||||
|
attach_to_ringbuf_schannel as attach_to_ringbuf_schannel,
|
||||||
|
attach_to_ringbuf_rchannel as attach_to_ringbuf_rchannel,
|
||||||
|
attach_to_ringbuf_channel as attach_to_ringbuf_channel,
|
||||||
)
|
)
|
||||||
|
|
|
@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
|
||||||
raise OSError(errno.errorcode[ffi.errno], 'close failed')
|
raise OSError(errno.errorcode[ffi.errno], 'close failed')
|
||||||
|
|
||||||
|
|
||||||
|
class EFDReadCancelled(Exception):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class EventFD:
|
class EventFD:
|
||||||
'''
|
'''
|
||||||
Use a previously opened eventfd(2), meant to be used in
|
Use a previously opened eventfd(2), meant to be used in
|
||||||
|
@ -124,6 +128,7 @@ class EventFD:
|
||||||
self._fd: int = fd
|
self._fd: int = fd
|
||||||
self._omode: str = omode
|
self._omode: str = omode
|
||||||
self._fobj = None
|
self._fobj = None
|
||||||
|
self._cscope: trio.CancelScope | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fd(self) -> int | None:
|
def fd(self) -> int | None:
|
||||||
|
@ -133,17 +138,46 @@ class EventFD:
|
||||||
return write_eventfd(self._fd, value)
|
return write_eventfd(self._fd, value)
|
||||||
|
|
||||||
async def read(self) -> int:
|
async def read(self) -> int:
|
||||||
return await trio.to_thread.run_sync(
|
'''
|
||||||
read_eventfd, self._fd,
|
Async wrapper for `read_eventfd(self.fd)`
|
||||||
abandon_on_cancel=True
|
|
||||||
)
|
`trio.to_thread.run_sync` is used, need to use a `trio.CancelScope`
|
||||||
|
in order to make it cancellable when `self.close()` is called.
|
||||||
|
|
||||||
|
'''
|
||||||
|
self._cscope = trio.CancelScope()
|
||||||
|
with self._cscope:
|
||||||
|
return await trio.to_thread.run_sync(
|
||||||
|
read_eventfd, self._fd,
|
||||||
|
abandon_on_cancel=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._cscope.cancelled_caught:
|
||||||
|
raise EFDReadCancelled
|
||||||
|
|
||||||
|
self._cscope = None
|
||||||
|
|
||||||
|
def read_direct(self) -> int:
|
||||||
|
'''
|
||||||
|
Direct call to `read_eventfd(self.fd)`, unless `eventfd` was
|
||||||
|
opened with `EFD_NONBLOCK` its gonna block the thread.
|
||||||
|
|
||||||
|
'''
|
||||||
|
return read_eventfd(self._fd)
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
self._fobj = os.fdopen(self._fd, self._omode)
|
self._fobj = os.fdopen(self._fd, self._omode)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self._fobj:
|
if self._fobj:
|
||||||
self._fobj.close()
|
try:
|
||||||
|
self._fobj.close()
|
||||||
|
|
||||||
|
except OSError:
|
||||||
|
...
|
||||||
|
|
||||||
|
if self._cscope:
|
||||||
|
self._cscope.cancel()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.open()
|
self.open()
|
||||||
|
|
|
@ -18,7 +18,15 @@ IPC Reliable RingBuffer implementation
|
||||||
|
|
||||||
'''
|
'''
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from contextlib import contextmanager as cm
|
import struct
|
||||||
|
from typing import (
|
||||||
|
ContextManager,
|
||||||
|
AsyncContextManager
|
||||||
|
)
|
||||||
|
from contextlib import (
|
||||||
|
contextmanager as cm,
|
||||||
|
asynccontextmanager as acm
|
||||||
|
)
|
||||||
from multiprocessing.shared_memory import SharedMemory
|
from multiprocessing.shared_memory import SharedMemory
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
@ -28,25 +36,37 @@ from msgspec import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from ._linux import (
|
from ._linux import (
|
||||||
EFD_NONBLOCK,
|
|
||||||
open_eventfd,
|
open_eventfd,
|
||||||
|
EFDReadCancelled,
|
||||||
EventFD
|
EventFD
|
||||||
)
|
)
|
||||||
from ._mp_bs import disable_mantracker
|
from ._mp_bs import disable_mantracker
|
||||||
|
from tractor.log import get_logger
|
||||||
|
from tractor._exceptions import (
|
||||||
|
InternalError
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
disable_mantracker()
|
disable_mantracker()
|
||||||
|
|
||||||
|
_DEFAULT_RB_SIZE = 10 * 1024
|
||||||
|
|
||||||
|
|
||||||
class RBToken(Struct, frozen=True):
|
class RBToken(Struct, frozen=True):
|
||||||
'''
|
'''
|
||||||
RingBuffer token contains necesary info to open the two
|
RingBuffer token contains necesary info to open the three
|
||||||
eventfds and the shared memory
|
eventfds and the shared memory
|
||||||
|
|
||||||
'''
|
'''
|
||||||
shm_name: str
|
shm_name: str
|
||||||
write_eventfd: int
|
|
||||||
wrap_eventfd: int
|
write_eventfd: int # used to signal writer ptr advance
|
||||||
|
wrap_eventfd: int # used to signal reader ready after wrap around
|
||||||
|
eof_eventfd: int # used to signal writer closed
|
||||||
|
|
||||||
buf_size: int
|
buf_size: int
|
||||||
|
|
||||||
def as_msg(self):
|
def as_msg(self):
|
||||||
|
@ -59,62 +79,97 @@ class RBToken(Struct, frozen=True):
|
||||||
|
|
||||||
return RBToken(**msg)
|
return RBToken(**msg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fds(self) -> tuple[int, int, int]:
|
||||||
|
'''
|
||||||
|
Useful for `pass_fds` params
|
||||||
|
|
||||||
|
'''
|
||||||
|
return (
|
||||||
|
self.write_eventfd,
|
||||||
|
self.wrap_eventfd,
|
||||||
|
self.eof_eventfd
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@cm
|
@cm
|
||||||
def open_ringbuf(
|
def open_ringbuf(
|
||||||
shm_name: str,
|
shm_name: str,
|
||||||
buf_size: int = 10 * 1024,
|
buf_size: int = _DEFAULT_RB_SIZE,
|
||||||
write_efd_flags: int = 0,
|
) -> ContextManager[RBToken]:
|
||||||
wrap_efd_flags: int = 0
|
'''
|
||||||
) -> RBToken:
|
Handle resources for a ringbuf (shm, eventfd), yield `RBToken` to
|
||||||
|
be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver`
|
||||||
|
|
||||||
|
'''
|
||||||
shm = SharedMemory(
|
shm = SharedMemory(
|
||||||
name=shm_name,
|
name=shm_name,
|
||||||
size=buf_size,
|
size=buf_size,
|
||||||
create=True
|
create=True
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
token = RBToken(
|
with (
|
||||||
shm_name=shm_name,
|
EventFD(open_eventfd(), 'r') as write_event,
|
||||||
write_eventfd=open_eventfd(flags=write_efd_flags),
|
EventFD(open_eventfd(), 'r') as wrap_event,
|
||||||
wrap_eventfd=open_eventfd(flags=wrap_efd_flags),
|
EventFD(open_eventfd(), 'r') as eof_event,
|
||||||
buf_size=buf_size
|
):
|
||||||
)
|
token = RBToken(
|
||||||
yield token
|
shm_name=shm_name,
|
||||||
|
write_eventfd=write_event.fd,
|
||||||
|
wrap_eventfd=wrap_event.fd,
|
||||||
|
eof_eventfd=eof_event.fd,
|
||||||
|
buf_size=buf_size
|
||||||
|
)
|
||||||
|
yield token
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
shm.unlink()
|
shm.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
Buffer = bytes | bytearray | memoryview
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
IPC Reliable Ring Buffer
|
||||||
|
|
||||||
|
`eventfd(2)` is used for wrap around sync, to signal writes to
|
||||||
|
the reader and end of stream.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
class RingBuffSender(trio.abc.SendStream):
|
class RingBuffSender(trio.abc.SendStream):
|
||||||
'''
|
'''
|
||||||
IPC Reliable Ring Buffer sender side implementation
|
Ring Buffer sender side implementation
|
||||||
|
|
||||||
`eventfd(2)` is used for wrap around sync, and also to signal
|
Do not use directly! manage with `attach_to_ringbuf_sender`
|
||||||
writes to the reader.
|
after having opened a ringbuf context with `open_ringbuf`.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
start_ptr: int = 0,
|
cleanup: bool = False
|
||||||
):
|
):
|
||||||
token = RBToken.from_msg(token)
|
self._token = RBToken.from_msg(token)
|
||||||
self._shm = SharedMemory(
|
self._shm: SharedMemory | None = None
|
||||||
name=token.shm_name,
|
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
||||||
size=token.buf_size,
|
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
||||||
create=False
|
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
|
||||||
)
|
self._ptr = 0
|
||||||
self._write_event = EventFD(token.write_eventfd, 'w')
|
|
||||||
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
|
self._cleanup = cleanup
|
||||||
self._ptr = start_ptr
|
self._send_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def name(self) -> str:
|
||||||
|
if not self._shm:
|
||||||
|
raise ValueError('shared memory not initialized yet!')
|
||||||
return self._shm.name
|
return self._shm.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def size(self) -> int:
|
def size(self) -> int:
|
||||||
return self._shm.size
|
return self._token.buf_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ptr(self) -> int:
|
def ptr(self) -> int:
|
||||||
|
@ -128,73 +183,97 @@ class RingBuffSender(trio.abc.SendStream):
|
||||||
def wrap_fd(self) -> int:
|
def wrap_fd(self) -> int:
|
||||||
return self._wrap_event.fd
|
return self._wrap_event.fd
|
||||||
|
|
||||||
async def send_all(self, data: bytes | bytearray | memoryview):
|
async def _wait_wrap(self):
|
||||||
# while data is larger than the remaining buf
|
await self._wrap_event.read()
|
||||||
target_ptr = self.ptr + len(data)
|
|
||||||
while target_ptr > self.size:
|
|
||||||
# write all bytes that fit
|
|
||||||
remaining = self.size - self.ptr
|
|
||||||
self._shm.buf[self.ptr:] = data[:remaining]
|
|
||||||
# signal write and wait for reader wrap around
|
|
||||||
self._write_event.write(remaining)
|
|
||||||
await self._wrap_event.read()
|
|
||||||
|
|
||||||
# wrap around and trim already written bytes
|
async def send_all(self, data: Buffer):
|
||||||
self._ptr = 0
|
async with self._send_lock:
|
||||||
data = data[remaining:]
|
# while data is larger than the remaining buf
|
||||||
target_ptr = self._ptr + len(data)
|
target_ptr = self.ptr + len(data)
|
||||||
|
while target_ptr > self.size:
|
||||||
|
# write all bytes that fit
|
||||||
|
remaining = self.size - self.ptr
|
||||||
|
self._shm.buf[self.ptr:] = data[:remaining]
|
||||||
|
# signal write and wait for reader wrap around
|
||||||
|
self._write_event.write(remaining)
|
||||||
|
await self._wait_wrap()
|
||||||
|
|
||||||
# remaining data fits on buffer
|
# wrap around and trim already written bytes
|
||||||
self._shm.buf[self.ptr:target_ptr] = data
|
self._ptr = 0
|
||||||
self._write_event.write(len(data))
|
data = data[remaining:]
|
||||||
self._ptr = target_ptr
|
target_ptr = self._ptr + len(data)
|
||||||
|
|
||||||
|
# remaining data fits on buffer
|
||||||
|
self._shm.buf[self.ptr:target_ptr] = data
|
||||||
|
self._write_event.write(len(data))
|
||||||
|
self._ptr = target_ptr
|
||||||
|
|
||||||
async def wait_send_all_might_not_block(self):
|
async def wait_send_all_might_not_block(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def aclose(self):
|
def open(self):
|
||||||
self._write_event.close()
|
self._shm = SharedMemory(
|
||||||
self._wrap_event.close()
|
name=self._token.shm_name,
|
||||||
self._shm.close()
|
size=self._token.buf_size,
|
||||||
|
create=False
|
||||||
async def __aenter__(self):
|
)
|
||||||
self._write_event.open()
|
self._write_event.open()
|
||||||
self._wrap_event.open()
|
self._wrap_event.open()
|
||||||
|
self._eof_event.open()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._eof_event.write(
|
||||||
|
self._ptr if self._ptr > 0 else self.size
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._cleanup:
|
||||||
|
self._write_event.close()
|
||||||
|
self._wrap_event.close()
|
||||||
|
self._eof_event.close()
|
||||||
|
self._shm.close()
|
||||||
|
|
||||||
|
async def aclose(self):
|
||||||
|
async with self._send_lock:
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self.open()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class RingBuffReceiver(trio.abc.ReceiveStream):
|
class RingBuffReceiver(trio.abc.ReceiveStream):
|
||||||
'''
|
'''
|
||||||
IPC Reliable Ring Buffer receiver side implementation
|
Ring Buffer receiver side implementation
|
||||||
|
|
||||||
`eventfd(2)` is used for wrap around sync, and also to signal
|
Do not use directly! manage with `attach_to_ringbuf_receiver`
|
||||||
writes to the reader.
|
after having opened a ringbuf context with `open_ringbuf`.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
start_ptr: int = 0,
|
cleanup: bool = True,
|
||||||
flags: int = 0
|
|
||||||
):
|
):
|
||||||
token = RBToken.from_msg(token)
|
self._token = RBToken.from_msg(token)
|
||||||
self._shm = SharedMemory(
|
self._shm: SharedMemory | None = None
|
||||||
name=token.shm_name,
|
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
||||||
size=token.buf_size,
|
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
||||||
create=False
|
self._eof_event = EventFD(self._token.eof_eventfd, 'r')
|
||||||
)
|
self._ptr: int = 0
|
||||||
self._write_event = EventFD(token.write_eventfd, 'w')
|
self._write_ptr: int = 0
|
||||||
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
|
self._end_ptr: int = -1
|
||||||
self._ptr = start_ptr
|
|
||||||
self._flags = flags
|
self._cleanup: bool = cleanup
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def name(self) -> str:
|
||||||
|
if not self._shm:
|
||||||
|
raise ValueError('shared memory not initialized yet!')
|
||||||
return self._shm.name
|
return self._shm.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def size(self) -> int:
|
def size(self) -> int:
|
||||||
return self._shm.size
|
return self._token.buf_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ptr(self) -> int:
|
def ptr(self) -> int:
|
||||||
|
@ -208,46 +287,368 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
|
||||||
def wrap_fd(self) -> int:
|
def wrap_fd(self) -> int:
|
||||||
return self._wrap_event.fd
|
return self._wrap_event.fd
|
||||||
|
|
||||||
async def receive_some(
|
async def _eof_monitor_task(self):
|
||||||
self,
|
'''
|
||||||
max_bytes: int | None = None,
|
Long running EOF event monitor, automatically run in bg by
|
||||||
nb_timeout: float = 0.1
|
`attach_to_ringbuf_receiver` context manager, if EOF event
|
||||||
) -> memoryview:
|
is set its value will be the end pointer (highest valid
|
||||||
# if non blocking eventfd enabled, do polling
|
index to be read from buf, after setting the `self._end_ptr`
|
||||||
# until next write, this allows signal handling
|
we close the write event which should cancel any blocked
|
||||||
if self._flags | EFD_NONBLOCK:
|
`self._write_event.read()`s on it.
|
||||||
delta = None
|
|
||||||
while delta is None:
|
'''
|
||||||
|
try:
|
||||||
|
self._end_ptr = await self._eof_event.read()
|
||||||
|
self._write_event.close()
|
||||||
|
|
||||||
|
except EFDReadCancelled:
|
||||||
|
...
|
||||||
|
|
||||||
|
except trio.Cancelled:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
||||||
|
'''
|
||||||
|
Receive up to `max_bytes`, if no `max_bytes` is provided
|
||||||
|
a reasonable default is used.
|
||||||
|
|
||||||
|
'''
|
||||||
|
if max_bytes is None:
|
||||||
|
max_bytes: int = _DEFAULT_RB_SIZE
|
||||||
|
|
||||||
|
if max_bytes < 1:
|
||||||
|
raise ValueError("max_bytes must be >= 1")
|
||||||
|
|
||||||
|
# delta is remaining bytes we havent read
|
||||||
|
delta = self._write_ptr - self._ptr
|
||||||
|
if delta == 0:
|
||||||
|
# we have read all we can, see if new data is available
|
||||||
|
if self._end_ptr < 0:
|
||||||
|
# if we havent been signaled about EOF yet
|
||||||
try:
|
try:
|
||||||
delta = await self._write_event.read()
|
delta = await self._write_event.read()
|
||||||
|
self._write_ptr += delta
|
||||||
|
|
||||||
except OSError as e:
|
except EFDReadCancelled:
|
||||||
if e.errno == 'EAGAIN':
|
# while waiting for new data `self._write_event` was closed
|
||||||
continue
|
# this means writer signaled EOF
|
||||||
|
if self._end_ptr > 0:
|
||||||
|
# final self._write_ptr modification and recalculate delta
|
||||||
|
self._write_ptr = self._end_ptr
|
||||||
|
delta = self._end_ptr - self._ptr
|
||||||
|
|
||||||
raise e
|
else:
|
||||||
|
# shouldnt happen cause self._eof_monitor_task always sets
|
||||||
|
# self._end_ptr before closing self._write_event
|
||||||
|
raise InternalError(
|
||||||
|
'self._write_event.read cancelled but self._end_ptr is not set'
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
delta = await self._write_event.read()
|
# no more bytes to read and self._end_ptr set, EOF reached
|
||||||
|
return b''
|
||||||
|
|
||||||
|
# dont overflow caller
|
||||||
|
delta = min(delta, max_bytes)
|
||||||
|
|
||||||
|
target_ptr = self._ptr + delta
|
||||||
|
|
||||||
# fetch next segment and advance ptr
|
# fetch next segment and advance ptr
|
||||||
next_ptr = self._ptr + delta
|
segment = bytes(self._shm.buf[self._ptr:target_ptr])
|
||||||
segment = self._shm.buf[self._ptr:next_ptr]
|
self._ptr = target_ptr
|
||||||
self._ptr = next_ptr
|
|
||||||
|
|
||||||
if self.ptr == self.size:
|
if self._ptr == self.size:
|
||||||
# reached the end, signal wrap around
|
# reached the end, signal wrap around
|
||||||
self._ptr = 0
|
self._ptr = 0
|
||||||
|
self._write_ptr = 0
|
||||||
self._wrap_event.write(1)
|
self._wrap_event.write(1)
|
||||||
|
|
||||||
return segment
|
return segment
|
||||||
|
|
||||||
async def aclose(self):
|
def open(self):
|
||||||
self._write_event.close()
|
self._shm = SharedMemory(
|
||||||
self._wrap_event.close()
|
name=self._token.shm_name,
|
||||||
self._shm.close()
|
size=self._token.buf_size,
|
||||||
|
create=False
|
||||||
async def __aenter__(self):
|
)
|
||||||
self._write_event.open()
|
self._write_event.open()
|
||||||
self._wrap_event.open()
|
self._wrap_event.open()
|
||||||
|
self._eof_event.open()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self._cleanup:
|
||||||
|
self._write_event.close()
|
||||||
|
self._wrap_event.close()
|
||||||
|
self._eof_event.close()
|
||||||
|
self._shm.close()
|
||||||
|
|
||||||
|
async def aclose(self):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self.open()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def attach_to_ringbuf_receiver(
|
||||||
|
token: RBToken,
|
||||||
|
cleanup: bool = True
|
||||||
|
) -> AsyncContextManager[RingBuffReceiver]:
|
||||||
|
'''
|
||||||
|
Attach a RingBuffReceiver from a previously opened
|
||||||
|
RBToken.
|
||||||
|
|
||||||
|
Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
|
||||||
|
'''
|
||||||
|
async with (
|
||||||
|
trio.open_nursery() as n,
|
||||||
|
RingBuffReceiver(
|
||||||
|
token,
|
||||||
|
cleanup=cleanup
|
||||||
|
) as receiver
|
||||||
|
):
|
||||||
|
n.start_soon(receiver._eof_monitor_task)
|
||||||
|
yield receiver
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def attach_to_ringbuf_sender(
|
||||||
|
token: RBToken,
|
||||||
|
cleanup: bool = True
|
||||||
|
) -> AsyncContextManager[RingBuffSender]:
|
||||||
|
'''
|
||||||
|
Attach a RingBuffSender from a previously opened
|
||||||
|
RBToken.
|
||||||
|
|
||||||
|
'''
|
||||||
|
async with RingBuffSender(
|
||||||
|
token,
|
||||||
|
cleanup=cleanup
|
||||||
|
) as sender:
|
||||||
|
yield sender
|
||||||
|
|
||||||
|
|
||||||
|
@cm
|
||||||
|
def open_ringbuf_pair(
|
||||||
|
name: str,
|
||||||
|
buf_size: int = _DEFAULT_RB_SIZE
|
||||||
|
) -> ContextManager[tuple(RBToken, RBToken)]:
|
||||||
|
'''
|
||||||
|
Handle resources for a ringbuf pair to be used for
|
||||||
|
bidirectional messaging.
|
||||||
|
|
||||||
|
'''
|
||||||
|
with (
|
||||||
|
open_ringbuf(
|
||||||
|
name + '.pair0',
|
||||||
|
buf_size=buf_size
|
||||||
|
) as token_0,
|
||||||
|
|
||||||
|
open_ringbuf(
|
||||||
|
name + '.pair1',
|
||||||
|
buf_size=buf_size
|
||||||
|
) as token_1
|
||||||
|
):
|
||||||
|
yield token_0, token_1
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def attach_to_ringbuf_stream(
|
||||||
|
token_in: RBToken,
|
||||||
|
token_out: RBToken,
|
||||||
|
cleanup_in: bool = True,
|
||||||
|
cleanup_out: bool = True
|
||||||
|
) -> AsyncContextManager[trio.StapledStream]:
|
||||||
|
'''
|
||||||
|
Attach a trio.StapledStream from a previously opened
|
||||||
|
ringbuf pair.
|
||||||
|
|
||||||
|
'''
|
||||||
|
async with (
|
||||||
|
attach_to_ringbuf_receiver(
|
||||||
|
token_in,
|
||||||
|
cleanup=cleanup_in
|
||||||
|
) as receiver,
|
||||||
|
attach_to_ringbuf_sender(
|
||||||
|
token_out,
|
||||||
|
cleanup=cleanup_out
|
||||||
|
) as sender,
|
||||||
|
):
|
||||||
|
yield trio.StapledStream(sender, receiver)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RingBuffBytesSender(trio.abc.SendChannel[bytes]):
|
||||||
|
'''
|
||||||
|
In order to guarantee full messages are received, all bytes
|
||||||
|
sent by `RingBuffBytesSender` are preceded with a 4 byte header
|
||||||
|
which decodes into a uint32 indicating the actual size of the
|
||||||
|
next payload.
|
||||||
|
|
||||||
|
Optional batch mode:
|
||||||
|
|
||||||
|
If `batch_size` > 1 messages wont get sent immediately but will be
|
||||||
|
stored until `batch_size` messages are pending, then it will send
|
||||||
|
them all at once.
|
||||||
|
|
||||||
|
`batch_size` can be changed dynamically but always call, `flush()`
|
||||||
|
right before.
|
||||||
|
|
||||||
|
'''
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sender: RingBuffSender,
|
||||||
|
batch_size: int = 1
|
||||||
|
):
|
||||||
|
self._sender = sender
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self._batch_msg_len = 0
|
||||||
|
self._batch: bytes = b''
|
||||||
|
|
||||||
|
async def flush(self) -> None:
|
||||||
|
await self._sender.send_all(self._batch)
|
||||||
|
self._batch = b''
|
||||||
|
self._batch_msg_len = 0
|
||||||
|
|
||||||
|
async def send(self, value: bytes) -> None:
|
||||||
|
msg: bytes = struct.pack("<I", len(value)) + value
|
||||||
|
if self.batch_size == 1:
|
||||||
|
await self._sender.send_all(msg)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._batch += msg
|
||||||
|
self._batch_msg_len += 1
|
||||||
|
if self._batch_msg_len == self.batch_size:
|
||||||
|
await self.flush()
|
||||||
|
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self._sender.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]):
|
||||||
|
'''
|
||||||
|
See `RingBuffBytesSender` docstring.
|
||||||
|
|
||||||
|
A `tricycle.BufferedReceiveStream` is used for the
|
||||||
|
`receive_exactly` API.
|
||||||
|
'''
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
receiver: RingBuffReceiver
|
||||||
|
):
|
||||||
|
self._receiver = receiver
|
||||||
|
|
||||||
|
async def _receive_exactly(self, num_bytes: int) -> bytes:
|
||||||
|
'''
|
||||||
|
Fetch bytes from receiver until we read exactly `num_bytes`
|
||||||
|
or end of stream is signaled.
|
||||||
|
|
||||||
|
'''
|
||||||
|
payload = b''
|
||||||
|
while len(payload) < num_bytes:
|
||||||
|
remaining = num_bytes - len(payload)
|
||||||
|
|
||||||
|
new_bytes = await self._receiver.receive_some(
|
||||||
|
max_bytes=remaining
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_bytes == b'':
|
||||||
|
raise trio.EndOfChannel
|
||||||
|
|
||||||
|
payload += new_bytes
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
async def receive(self) -> bytes:
|
||||||
|
header: bytes = await self._receive_exactly(4)
|
||||||
|
size: int
|
||||||
|
size, = struct.unpack("<I", header)
|
||||||
|
if size == 0:
|
||||||
|
raise trio.EndOfChannel
|
||||||
|
return await self._receive_exactly(size)
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self._receiver.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def attach_to_ringbuf_rchannel(
|
||||||
|
token: RBToken,
|
||||||
|
cleanup: bool = True
|
||||||
|
) -> AsyncContextManager[RingBuffBytesReceiver]:
|
||||||
|
'''
|
||||||
|
Attach a RingBuffBytesReceiver from a previously opened
|
||||||
|
RBToken.
|
||||||
|
'''
|
||||||
|
async with attach_to_ringbuf_receiver(
|
||||||
|
token, cleanup=cleanup
|
||||||
|
) as receiver:
|
||||||
|
yield RingBuffBytesReceiver(receiver)
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def attach_to_ringbuf_schannel(
|
||||||
|
token: RBToken,
|
||||||
|
cleanup: bool = True,
|
||||||
|
batch_size: int = 1,
|
||||||
|
) -> AsyncContextManager[RingBuffBytesSender]:
|
||||||
|
'''
|
||||||
|
Attach a RingBuffBytesSender from a previously opened
|
||||||
|
RBToken.
|
||||||
|
'''
|
||||||
|
async with attach_to_ringbuf_sender(
|
||||||
|
token, cleanup=cleanup
|
||||||
|
) as sender:
|
||||||
|
yield RingBuffBytesSender(sender, batch_size=batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
class RingBuffChannel(trio.abc.Channel[bytes]):
|
||||||
|
'''
|
||||||
|
Combine `RingBuffBytesSender` and `RingBuffBytesReceiver`
|
||||||
|
in order to expose the bidirectional `trio.abc.Channel` API.
|
||||||
|
|
||||||
|
'''
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sender: RingBuffBytesSender,
|
||||||
|
receiver: RingBuffBytesReceiver
|
||||||
|
):
|
||||||
|
self._sender = sender
|
||||||
|
self._receiver = receiver
|
||||||
|
|
||||||
|
async def send(self, value: bytes):
|
||||||
|
await self._sender.send(value)
|
||||||
|
|
||||||
|
async def receive(self) -> bytes:
|
||||||
|
return await self._receiver.receive()
|
||||||
|
|
||||||
|
async def aclose(self):
|
||||||
|
await self._receiver.aclose()
|
||||||
|
await self._sender.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def attach_to_ringbuf_channel(
|
||||||
|
token_in: RBToken,
|
||||||
|
token_out: RBToken,
|
||||||
|
cleanup_in: bool = True,
|
||||||
|
cleanup_out: bool = True
|
||||||
|
) -> AsyncContextManager[RingBuffChannel]:
|
||||||
|
'''
|
||||||
|
Attach to an already opened ringbuf pair and return
|
||||||
|
a `RingBuffChannel`.
|
||||||
|
|
||||||
|
'''
|
||||||
|
async with (
|
||||||
|
attach_to_ringbuf_rchannel(
|
||||||
|
token_in,
|
||||||
|
cleanup=cleanup_in
|
||||||
|
) as receiver,
|
||||||
|
attach_to_ringbuf_schannel(
|
||||||
|
token_out,
|
||||||
|
cleanup=cleanup_out
|
||||||
|
) as sender,
|
||||||
|
):
|
||||||
|
yield RingBuffChannel(sender, receiver)
|
||||||
|
|
|
@ -73,7 +73,7 @@ class MsgTransport(Protocol[MsgType]):
|
||||||
# eventual msg definition/types?
|
# eventual msg definition/types?
|
||||||
# - https://docs.python.org/3/library/typing.html#typing.Protocol
|
# - https://docs.python.org/3/library/typing.html#typing.Protocol
|
||||||
|
|
||||||
stream: trio.SocketStream
|
stream: trio.abc.Stream
|
||||||
drained: list[MsgType]
|
drained: list[MsgType]
|
||||||
|
|
||||||
address_type: ClassVar[Type[Address]]
|
address_type: ClassVar[Type[Address]]
|
||||||
|
|
Loading…
Reference in New Issue