283 lines
7.6 KiB
Python
283 lines
7.6 KiB
Python
from contextlib import asynccontextmanager as acm
|
|
|
|
import trio
|
|
import tractor
|
|
import msgspec
|
|
|
|
from tractor.ipc import (
|
|
attach_to_ringbuf_receiver,
|
|
attach_to_ringbuf_sender
|
|
)
|
|
from tractor.ipc._ringbuf._pubsub import (
|
|
open_ringbuf_publisher,
|
|
open_ringbuf_subscriber
|
|
)
|
|
|
|
import tractor.ipc._ringbuf._ringd as ringd
|
|
|
|
|
|
log = tractor.log.get_console_log(level='info')
|
|
|
|
|
|
@tractor.context
|
|
async def recv_child(
|
|
ctx: tractor.Context,
|
|
ring_name: str
|
|
):
|
|
async with (
|
|
ringd.open_ringbuf(ring_name) as token,
|
|
|
|
attach_to_ringbuf_receiver(token) as chan,
|
|
):
|
|
await ctx.started()
|
|
async for msg in chan:
|
|
log.info(f'received {int.from_bytes(msg)}')
|
|
|
|
|
|
@tractor.context
|
|
async def send_child(
|
|
ctx: tractor.Context,
|
|
ring_name: str
|
|
):
|
|
async with (
|
|
ringd.open_ringbuf(ring_name) as token,
|
|
|
|
attach_to_ringbuf_sender(token) as chan,
|
|
):
|
|
await ctx.started()
|
|
for i in range(100):
|
|
await chan.send(i.to_bytes(4))
|
|
log.info(f'sent {i}')
|
|
|
|
|
|
|
|
def test_ringd():
|
|
'''
|
|
Spawn ringd actor and two childs that access same ringbuf through ringd.
|
|
|
|
Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
|
|
them as sender and receiver.
|
|
|
|
'''
|
|
async def main():
|
|
async with (
|
|
tractor.open_nursery() as an,
|
|
|
|
ringd.open_ringd(
|
|
loglevel='info'
|
|
)
|
|
):
|
|
recv_portal = await an.start_actor(
|
|
'recv',
|
|
enable_modules=[__name__]
|
|
)
|
|
send_portal = await an.start_actor(
|
|
'send',
|
|
enable_modules=[__name__]
|
|
)
|
|
|
|
async with (
|
|
recv_portal.open_context(
|
|
recv_child,
|
|
ring_name='ring'
|
|
) as (rctx, _),
|
|
|
|
send_portal.open_context(
|
|
send_child,
|
|
ring_name='ring'
|
|
) as (sctx, _),
|
|
):
|
|
...
|
|
|
|
await an.cancel()
|
|
|
|
trio.run(main)
|
|
|
|
|
|
class Struct(msgspec.Struct):
|
|
|
|
def encode(self) -> bytes:
|
|
return msgspec.msgpack.encode(self)
|
|
|
|
|
|
class AddChannelMsg(Struct, frozen=True, tag=True):
|
|
name: str
|
|
|
|
|
|
class RemoveChannelMsg(Struct, frozen=True, tag=True):
|
|
name: str
|
|
|
|
|
|
class RangeMsg(Struct, frozen=True, tag=True):
|
|
size: int
|
|
|
|
|
|
ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
|
|
|
|
|
|
@tractor.context
|
|
async def subscriber_child(ctx: tractor.Context):
|
|
await ctx.started()
|
|
async with (
|
|
open_ringbuf_subscriber(guarantee_order=True) as subs,
|
|
trio.open_nursery() as n,
|
|
ctx.open_stream() as stream
|
|
):
|
|
range_msg = None
|
|
range_event = trio.Event()
|
|
range_scope = trio.CancelScope()
|
|
|
|
async def _control_listen_task():
|
|
nonlocal range_msg, range_event
|
|
async for msg in stream:
|
|
msg = msgspec.msgpack.decode(msg, type=ControlMessages)
|
|
match msg:
|
|
case AddChannelMsg():
|
|
await subs.add_channel(msg.name, must_exist=False)
|
|
|
|
case RemoveChannelMsg():
|
|
await subs.remove_channel(msg.name)
|
|
|
|
case RangeMsg():
|
|
range_msg = msg
|
|
range_event.set()
|
|
|
|
await stream.send(b'ack')
|
|
|
|
range_scope.cancel()
|
|
|
|
n.start_soon(_control_listen_task)
|
|
|
|
with range_scope:
|
|
while True:
|
|
await range_event.wait()
|
|
range_event = trio.Event()
|
|
for i in range(range_msg.size):
|
|
recv = int.from_bytes(await subs.receive())
|
|
if recv != i:
|
|
raise AssertionError(
|
|
f'received: {recv} expected: {i}'
|
|
)
|
|
|
|
log.info(f'received: {recv}')
|
|
|
|
await stream.send(b'valid range')
|
|
log.info('finished range')
|
|
|
|
log.info('subscriber exit')
|
|
|
|
|
|
@tractor.context
|
|
async def publisher_child(ctx: tractor.Context):
|
|
await ctx.started()
|
|
async with (
|
|
open_ringbuf_publisher(batch_size=100, guarantee_order=True) as pub,
|
|
ctx.open_stream() as stream
|
|
):
|
|
async for msg in stream:
|
|
msg = msgspec.msgpack.decode(msg, type=ControlMessages)
|
|
match msg:
|
|
case AddChannelMsg():
|
|
await pub.add_channel(msg.name, must_exist=True)
|
|
|
|
case RemoveChannelMsg():
|
|
await pub.remove_channel(msg.name)
|
|
|
|
case RangeMsg():
|
|
for i in range(msg.size):
|
|
await pub.send(i.to_bytes(4))
|
|
log.info(f'sent {i}')
|
|
|
|
await stream.send(b'ack')
|
|
|
|
log.info('publisher exit')
|
|
|
|
|
|
|
|
def test_pubsub():
|
|
'''
|
|
Spawn ringd actor and two childs that access same ringbuf through ringd.
|
|
|
|
Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
|
|
them as sender and receiver.
|
|
|
|
'''
|
|
async def main():
|
|
async with (
|
|
tractor.open_nursery(
|
|
loglevel='info',
|
|
# debug_mode=True,
|
|
# enable_stack_on_sig=True
|
|
) as an,
|
|
|
|
ringd.open_ringd()
|
|
):
|
|
recv_portal = await an.start_actor(
|
|
'recv',
|
|
enable_modules=[__name__]
|
|
)
|
|
send_portal = await an.start_actor(
|
|
'send',
|
|
enable_modules=[__name__]
|
|
)
|
|
|
|
async with (
|
|
recv_portal.open_context(subscriber_child) as (rctx, _),
|
|
rctx.open_stream() as recv_stream,
|
|
send_portal.open_context(publisher_child) as (sctx, _),
|
|
sctx.open_stream() as send_stream,
|
|
):
|
|
async def send_wait_ack(msg: bytes):
|
|
await recv_stream.send(msg)
|
|
ack = await recv_stream.receive()
|
|
assert ack == b'ack'
|
|
|
|
await send_stream.send(msg)
|
|
ack = await send_stream.receive()
|
|
assert ack == b'ack'
|
|
|
|
async def add_channel(name: str):
|
|
await send_wait_ack(AddChannelMsg(name=name).encode())
|
|
|
|
async def remove_channel(name: str):
|
|
await send_wait_ack(RemoveChannelMsg(name=name).encode())
|
|
|
|
async def send_range(size: int):
|
|
await send_wait_ack(RangeMsg(size=size).encode())
|
|
range_ack = await recv_stream.receive()
|
|
assert range_ack == b'valid range'
|
|
|
|
# simple test, open one channel and send 0..100 range
|
|
ring_name = 'ring-first'
|
|
await add_channel(ring_name)
|
|
await send_range(100)
|
|
await remove_channel(ring_name)
|
|
|
|
# redo
|
|
ring_name = 'ring-redo'
|
|
await add_channel(ring_name)
|
|
await send_range(100)
|
|
await remove_channel(ring_name)
|
|
|
|
# try using same ring name
|
|
await add_channel(ring_name)
|
|
await send_range(100)
|
|
await remove_channel(ring_name)
|
|
|
|
# multi chan test
|
|
ring_names = []
|
|
for i in range(3):
|
|
ring_names.append(f'multi-ring-{i}')
|
|
|
|
for name in ring_names:
|
|
await add_channel(name)
|
|
|
|
await send_range(1000)
|
|
|
|
for name in ring_names:
|
|
await remove_channel(name)
|
|
|
|
await an.cancel()
|
|
|
|
trio.run(main)
|