Add optional msgpack encoder & decoder to ringbuf apis

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-22 01:45:44 -03:00
parent 86e09a80f4
commit 8799cf3b78
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 67 additions and 14 deletions

View File

@ -20,6 +20,7 @@ IPC Reliable RingBuffer implementation
from __future__ import annotations
import struct
from typing import (
TypeVar,
ContextManager,
AsyncContextManager
)
@ -34,6 +35,10 @@ from msgspec import (
Struct,
to_builtins
)
from msgspec.msgpack import (
Encoder,
Decoder,
)
from tractor.log import get_logger
from tractor._exceptions import (
@ -277,7 +282,10 @@ next full payload.
'''
class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
PayloadT = TypeVar('PayloadT')
class RingBufferSendChannel(trio.abc.SendChannel[PayloadT]):
'''
Ring Buffer sender side implementation
@ -298,7 +306,8 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self,
token: RBToken,
batch_size: int = 1,
cleanup: bool = False
cleanup: bool = False,
encoder: Encoder | None = None
):
self._token = RBToken.from_msg(token)
self.batch_size = batch_size
@ -319,6 +328,8 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
# close shm & fds on exit?
self._cleanup: bool = cleanup
self._enc: Encoder | None = encoder
# have we closed this ringbuf?
# set to `False` on `.open()`
self._is_closed: bool = True
@ -415,15 +426,22 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
if new_batch_size:
self.batch_size = new_batch_size
async def send(self, value: bytes) -> None:
async def send(self, value: PayloadT) -> None:
if self.closed:
raise trio.ClosedResourceError
if self._send_lock.locked():
raise trio.BusyResourceError
raw_value: bytes = (
value
if isinstance(value, bytes)
else
self._enc.encode(value)
)
async with self._send_lock:
msg: bytes = struct.pack("<I", len(value)) + value
msg: bytes = struct.pack("<I", len(raw_value)) + raw_value
if self.batch_size == 1:
if len(self._batch) > 0:
await self.flush()
@ -475,7 +493,7 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
return self
class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
class RingBufferReceiveChannel(trio.abc.ReceiveChannel[PayloadT]):
'''
Ring Buffer receiver side implementation
@ -487,6 +505,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self,
token: RBToken,
cleanup: bool = True,
decoder: Decoder | None = None
):
self._token = RBToken.from_msg(token)
@ -513,6 +532,8 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
# set to `False` on `.open()`
self._is_closed: bool = True
self._dec: Decoder | None = decoder
# ensure no concurrent `.receive_some()` calls
self._receive_some_lock = trio.StrictFIFOLock()
@ -648,7 +669,11 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self._write_ptr += delta
# yield lock and re-enter
except (EFDReadCancelled, trio.Cancelled):
except (
EFDReadCancelled, # read was cancelled with cscope
trio.Cancelled, # read got cancelled from outside
trio.BrokenResourceError # OSError EBADF happened while reading
):
# while waiting for new data `self._write_event` was closed
try:
# if eof was signaled receive no wait will not raise
@ -699,7 +724,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
return payload
async def receive(self) -> bytes:
async def receive(self, raw: bool = False) -> PayloadT:
'''
Receive a complete payload or raise EOC
@ -717,7 +742,27 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
if size == 0:
raise trio.EndOfChannel
return await self.receive_exactly(size)
raw_msg = await self.receive_exactly(size)
if raw:
return raw_msg
return (
raw_msg
if not self._dec
else self._dec.decode(raw_msg)
)
async def iter_raw_pairs(self) -> tuple[bytes, PayloadT]:
if not self._dec:
raise RuntimeError('iter_raw_pair requires decoder')
while True:
try:
raw = await self.receive(raw=True)
yield raw, self._dec.decode(raw)
except trio.EndOfChannel:
break
def open(self):
try:
@ -782,7 +827,8 @@ async def _maybe_obtain_shared_resources(token: RBToken):
async def attach_to_ringbuf_receiver(
token: RBToken,
cleanup: bool = True
cleanup: bool = True,
decoder: Decoder | None = None
) -> AsyncContextManager[RingBufferReceiveChannel]:
'''
@ -800,7 +846,8 @@ async def attach_to_ringbuf_receiver(
trio.open_nursery(strict_exception_groups=False) as n,
RingBufferReceiveChannel(
token,
cleanup=cleanup
cleanup=cleanup,
decoder=decoder
) as receiver
):
n.start_soon(receiver._eof_monitor_task)
@ -812,7 +859,8 @@ async def attach_to_ringbuf_sender(
token: RBToken,
batch_size: int = 1,
cleanup: bool = True
cleanup: bool = True,
encoder: Encoder | None = None
) -> AsyncContextManager[RingBufferSendChannel]:
'''
@ -828,7 +876,8 @@ async def attach_to_ringbuf_sender(
async with RingBufferSendChannel(
token,
batch_size=batch_size,
cleanup=cleanup
cleanup=cleanup,
encoder=encoder
) as sender:
yield sender
@ -901,6 +950,8 @@ async def attach_to_ringbuf_channel(
batch_size: int = 1,
cleanup_in: bool = True,
cleanup_out: bool = True,
encoder: Encoder | None = None,
decoder: Decoder | None = None
) -> AsyncContextManager[trio.StapledStream]:
'''
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
@ -909,12 +960,14 @@ async def attach_to_ringbuf_channel(
async with (
attach_to_ringbuf_receiver(
token_in,
cleanup=cleanup_in
cleanup=cleanup_in,
decoder=decoder
) as receiver,
attach_to_ringbuf_sender(
token_out,
batch_size=batch_size,
cleanup=cleanup_out
cleanup=cleanup_out,
encoder=encoder
) as sender,
):
yield RingBufferChannel(sender, receiver)