Improve ringd ringbuf lifecycle

Unlink sock after use in fdshare
one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-04 02:41:50 -03:00
parent 3568ba5d5d
commit bebd327023
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
3 changed files with 293 additions and 44 deletions

View File

@ -1,9 +1,14 @@
import trio import trio
import tractor import tractor
import msgspec
from tractor.ipc import ( from tractor.ipc import (
attach_to_ringbuf_rchannel, attach_to_ringbuf_receiver,
attach_to_ringbuf_schannel attach_to_ringbuf_sender
)
from tractor.ipc._ringbuf._pubsub import (
open_ringbuf_publisher,
open_ringbuf_subscriber
) )
import tractor.ipc._ringbuf._ringd as ringd import tractor.ipc._ringbuf._ringd as ringd
@ -20,7 +25,7 @@ async def recv_child(
async with ( async with (
ringd.open_ringbuf(ring_name) as token, ringd.open_ringbuf(ring_name) as token,
attach_to_ringbuf_rchannel(token) as chan, attach_to_ringbuf_receiver(token) as chan,
): ):
await ctx.started() await ctx.started()
async for msg in chan: async for msg in chan:
@ -35,7 +40,7 @@ async def send_child(
async with ( async with (
ringd.open_ringbuf(ring_name) as token, ringd.open_ringbuf(ring_name) as token,
attach_to_ringbuf_schannel(token) as chan, attach_to_ringbuf_sender(token) as chan,
): ):
await ctx.started() await ctx.started()
for i in range(100): for i in range(100):
@ -45,6 +50,13 @@ async def send_child(
def test_ringd(): def test_ringd():
'''
Spawn ringd actor and two childs that access same ringbuf through ringd.
Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
them as sender and receiver.
'''
async def main(): async def main():
async with ( async with (
tractor.open_nursery() as an, tractor.open_nursery() as an,
@ -73,9 +85,194 @@ def test_ringd():
ring_name='ring' ring_name='ring'
) as (sctx, _), ) as (sctx, _),
): ):
await rctx.wait_for_result() ...
await sctx.wait_for_result()
await an.cancel() await an.cancel()
trio.run(main) trio.run(main)
# class Struct(msgspec.Struct):
#
# def encode(self) -> bytes:
# return msgspec.msgpack.encode(self)
#
#
# class AddChannelMsg(Struct, frozen=True, tag=True):
# name: str
#
#
# class RemoveChannelMsg(Struct, frozen=True, tag=True):
# name: str
#
#
# class RangeMsg(Struct, frozen=True, tag=True):
# start: int
# end: int
#
#
# ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
#
#
# @tractor.context
# async def subscriber_child(ctx: tractor.Context):
# await ctx.started()
# async with (
# open_ringbuf_subscriber(guarantee_order=True) as subs,
# trio.open_nursery() as n,
# ctx.open_stream() as stream
# ):
# range_msg = None
# range_event = trio.Event()
# range_scope = trio.CancelScope()
#
# async def _control_listen_task():
# nonlocal range_msg, range_event
# async for msg in stream:
# msg = msgspec.msgpack.decode(msg, type=ControlMessages)
# match msg:
# case AddChannelMsg():
# await subs.add_channel(msg.name, must_exist=False)
#
# case RemoveChannelMsg():
# await subs.remove_channel(msg.name)
#
# case RangeMsg():
# range_msg = msg
# range_event.set()
#
# await stream.send(b'ack')
#
# range_scope.cancel()
#
# n.start_soon(_control_listen_task)
#
# with range_scope:
# while True:
# await range_event.wait()
# range_event = trio.Event()
# for i in range(range_msg.start, range_msg.end):
# recv = int.from_bytes(await subs.receive())
# # if recv != i:
# # raise AssertionError(
# # f'received: {recv} expected: {i}'
# # )
#
# log.info(f'received: {recv} expected: {i}')
#
# await stream.send(b'valid range')
# log.info('FINISHED RANGE')
#
# log.info('subscriber exit')
#
#
# @tractor.context
# async def publisher_child(ctx: tractor.Context):
# await ctx.started()
# async with (
# open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
# ctx.open_stream() as stream
# ):
# abs_index = 0
# async for msg in stream:
# msg = msgspec.msgpack.decode(msg, type=ControlMessages)
# match msg:
# case AddChannelMsg():
# await pub.add_channel(msg.name, must_exist=True)
#
# case RemoveChannelMsg():
# await pub.remove_channel(msg.name)
#
# case RangeMsg():
# for i in range(msg.start, msg.end):
# await pub.send(i.to_bytes(4))
# log.info(f'sent {i}, index: {abs_index}')
# abs_index += 1
#
# await stream.send(b'ack')
#
# log.info('publisher exit')
#
#
#
# def test_pubsub():
# '''
# Spawn ringd actor and two childs that access same ringbuf through ringd.
#
# Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
# them as sender and receiver.
#
# '''
# async def main():
# async with (
# tractor.open_nursery(
# loglevel='info',
# # debug_mode=True,
# # enable_stack_on_sig=True
# ) as an,
#
# ringd.open_ringd()
# ):
# recv_portal = await an.start_actor(
# 'recv',
# enable_modules=[__name__]
# )
# send_portal = await an.start_actor(
# 'send',
# enable_modules=[__name__]
# )
#
# async with (
# recv_portal.open_context(subscriber_child) as (rctx, _),
# rctx.open_stream() as recv_stream,
# send_portal.open_context(publisher_child) as (sctx, _),
# sctx.open_stream() as send_stream,
# ):
# async def send_wait_ack(msg: bytes):
# await recv_stream.send(msg)
# ack = await recv_stream.receive()
# assert ack == b'ack'
#
# await send_stream.send(msg)
# ack = await send_stream.receive()
# assert ack == b'ack'
#
# async def add_channel(name: str):
# await send_wait_ack(AddChannelMsg(name=name).encode())
#
# async def remove_channel(name: str):
# await send_wait_ack(RemoveChannelMsg(name=name).encode())
#
# async def send_range(start: int, end: int):
# await send_wait_ack(RangeMsg(start=start, end=end).encode())
# range_ack = await recv_stream.receive()
# assert range_ack == b'valid range'
#
# # simple test, open one channel and send 0..100 range
# ring_name = 'ring-first'
# await add_channel(ring_name)
# await send_range(0, 100)
# await remove_channel(ring_name)
#
# # redo
# ring_name = 'ring-redo'
# await add_channel(ring_name)
# await send_range(0, 100)
# await remove_channel(ring_name)
#
# # multi chan test
# ring_names = []
# for i in range(3):
# ring_names.append(f'multi-ring-{i}')
#
# for name in ring_names:
# await add_channel(name)
#
# await send_range(0, 300)
#
# for name in ring_names:
# await remove_channel(name)
#
# await an.cancel()
#
# trio.run(main)

View File

@ -29,6 +29,7 @@ from pathlib import Path
from contextlib import ( from contextlib import (
asynccontextmanager as acm asynccontextmanager as acm
) )
from dataclasses import dataclass
import trio import trio
import tractor import tractor
@ -42,12 +43,41 @@ log = tractor.log.get_logger(__name__)
# log = tractor.log.get_console_log(level='info') # log = tractor.log.get_console_log(level='info')
class RingNotFound(Exception):
...
_ringd_actor_name = 'ringd' _ringd_actor_name = 'ringd'
_root_key = _ringd_actor_name + f'-{os.getpid()}' _root_key = _ringd_actor_name + f'-{os.getpid()}'
_rings: dict[str, RBToken] = {}
@dataclass
class RingInfo:
token: RBToken
creator: str
unlink: trio.Event()
_rings: dict[str, RingInfo] = {}
def _maybe_get_ring(name: str) -> RingInfo | None:
if name in _rings:
return _rings[name]
return None
def _insert_ring(name: str, info: RingInfo):
_rings[name] = info
def _destroy_ring(name: str):
del _rings[name]
async def _attach_to_ring( async def _attach_to_ring(
ringd_pid: int,
ring_name: str ring_name: str
) -> RBToken: ) -> RBToken:
actor = tractor.current_actor() actor = tractor.current_actor()
@ -56,7 +86,7 @@ async def _attach_to_ring(
sock_path = str( sock_path = str(
Path(tempfile.gettempdir()) Path(tempfile.gettempdir())
/ /
f'{os.getpid()}-pass-ring-fds-{ring_name}-to-{actor.name}.sock' f'ringd-{ringd_pid}-{ring_name}-to-{actor.name}.sock'
) )
log.info(f'trying to attach to ring {ring_name}...') log.info(f'trying to attach to ring {ring_name}...')
@ -94,8 +124,12 @@ async def _pass_fds(
sock_path: str sock_path: str
): ):
global _rings global _rings
info = _maybe_get_ring(name)
token = _rings[name] if not info:
raise RingNotFound(f'Ring \"{name}\" not found!')
token = info.token
async with send_fds(token.fds, sock_path): async with send_fds(token.fds, sock_path):
log.info(f'connected to {sock_path} for fd passing') log.info(f'connected to {sock_path} for fd passing')
@ -109,48 +143,58 @@ async def _pass_fds(
@tractor.context @tractor.context
async def _open_ringbuf( async def _open_ringbuf(
ctx: tractor.Context, ctx: tractor.Context,
caller: str,
name: str, name: str,
buf_size: int = 10 * 1024,
must_exist: bool = False, must_exist: bool = False,
buf_size: int = 10 * 1024
): ):
global _root_key, _rings global _root_key, _rings
log.info(f'maybe open ring {name} from {caller}, must_exist = {must_exist}')
info = _maybe_get_ring(name)
if info:
log.info(f'ring {name} exists, {caller} attached')
await ctx.started(os.getpid())
teardown = trio.Event()
async def _teardown_listener(task_status=trio.TASK_STATUS_IGNORED):
async with ctx.open_stream() as stream: async with ctx.open_stream() as stream:
task_status.started()
await stream.receive() await stream.receive()
teardown.set()
log.info(f'maybe open ring {name}, must_exist = {must_exist}') info.unlink.set()
token = _rings.get(name, None) log.info(f'{caller} detached from ring {name}')
async with trio.open_nursery() as n: return
if token:
log.info(f'ring {name} exists')
await ctx.started()
await n.start(_teardown_listener)
await teardown.wait()
return
if must_exist: if must_exist:
raise FileNotFoundError( raise RingNotFound(
f'Tried to open_ringbuf but it doesn\'t exist: {name}' f'Tried to open_ringbuf but it doesn\'t exist: {name}'
)
with ringbuf.open_ringbuf(
_root_key + name,
buf_size=buf_size
) as token:
unlink_event = trio.Event()
_insert_ring(
name,
RingInfo(
token=token,
creator=caller,
unlink=unlink_event,
) )
)
log.info(f'ring {name} created by {caller}')
await ctx.started(os.getpid())
with ringbuf.open_ringbuf( async with ctx.open_stream() as stream:
_root_key + name, await stream.receive()
buf_size=buf_size
) as token:
_rings[name] = token
log.info(f'ring {name} created')
await ctx.started()
await n.start(_teardown_listener)
await teardown.wait()
del _rings[name]
log.info(f'ring {name} destroyed') await unlink_event.wait()
_destroy_ring(name)
log.info(f'ring {name} destroyed by {caller}')
@acm @acm
@ -174,22 +218,28 @@ async def wait_for_ringd() -> tractor.Portal:
@acm @acm
async def open_ringbuf( async def open_ringbuf(
name: str, name: str,
buf_size: int = 10 * 1024,
must_exist: bool = False, must_exist: bool = False,
buf_size: int = 10 * 1024
) -> RBToken: ) -> RBToken:
actor = tractor.current_actor()
async with ( async with (
wait_for_ringd() as ringd, wait_for_ringd() as ringd,
ringd.open_context( ringd.open_context(
_open_ringbuf, _open_ringbuf,
caller=actor.name,
name=name, name=name,
must_exist=must_exist, buf_size=buf_size,
buf_size=buf_size must_exist=must_exist
) as (rd_ctx, _), ) as (rd_ctx, ringd_pid),
rd_ctx.open_stream() as stream,
rd_ctx.open_stream() as _stream,
): ):
token = await _attach_to_ring(name) token = await _attach_to_ring(ringd_pid, name)
log.info(f'attached to {token}') log.info(f'attached to {token}')
yield token yield token
await stream.send(b'bye')

View File

@ -19,6 +19,7 @@ Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and
cpython impl: cpython impl:
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138 https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
''' '''
import os
import array import array
from typing import AsyncContextManager from typing import AsyncContextManager
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
@ -73,6 +74,7 @@ async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
conn.close() conn.close()
sock.close() sock.close()
os.unlink(sock_path)
async def recv_fds(sock_path: str, amount: int) -> tuple: async def recv_fds(sock_path: str, amount: int) -> tuple: