Add optional msgpack encoder & decoder to ringbuf apis
parent
86e09a80f4
commit
8799cf3b78
|
@ -20,6 +20,7 @@ IPC Reliable RingBuffer implementation
|
|||
from __future__ import annotations
|
||||
import struct
|
||||
from typing import (
|
||||
TypeVar,
|
||||
ContextManager,
|
||||
AsyncContextManager
|
||||
)
|
||||
|
@ -34,6 +35,10 @@ from msgspec import (
|
|||
Struct,
|
||||
to_builtins
|
||||
)
|
||||
from msgspec.msgpack import (
|
||||
Encoder,
|
||||
Decoder,
|
||||
)
|
||||
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
|
@ -277,7 +282,10 @@ next full payload.
|
|||
'''
|
||||
|
||||
|
||||
class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||
PayloadT = TypeVar('PayloadT')
|
||||
|
||||
|
||||
class RingBufferSendChannel(trio.abc.SendChannel[PayloadT]):
|
||||
'''
|
||||
Ring Buffer sender side implementation
|
||||
|
||||
|
@ -298,7 +306,8 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
self,
|
||||
token: RBToken,
|
||||
batch_size: int = 1,
|
||||
cleanup: bool = False
|
||||
cleanup: bool = False,
|
||||
encoder: Encoder | None = None
|
||||
):
|
||||
self._token = RBToken.from_msg(token)
|
||||
self.batch_size = batch_size
|
||||
|
@ -319,6 +328,8 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
# close shm & fds on exit?
|
||||
self._cleanup: bool = cleanup
|
||||
|
||||
self._enc: Encoder | None = encoder
|
||||
|
||||
# have we closed this ringbuf?
|
||||
# set to `False` on `.open()`
|
||||
self._is_closed: bool = True
|
||||
|
@ -415,15 +426,22 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
if new_batch_size:
|
||||
self.batch_size = new_batch_size
|
||||
|
||||
async def send(self, value: bytes) -> None:
|
||||
async def send(self, value: PayloadT) -> None:
|
||||
if self.closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
if self._send_lock.locked():
|
||||
raise trio.BusyResourceError
|
||||
|
||||
raw_value: bytes = (
|
||||
value
|
||||
if isinstance(value, bytes)
|
||||
else
|
||||
self._enc.encode(value)
|
||||
)
|
||||
|
||||
async with self._send_lock:
|
||||
msg: bytes = struct.pack("<I", len(value)) + value
|
||||
msg: bytes = struct.pack("<I", len(raw_value)) + raw_value
|
||||
if self.batch_size == 1:
|
||||
if len(self._batch) > 0:
|
||||
await self.flush()
|
||||
|
@ -475,7 +493,7 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
return self
|
||||
|
||||
|
||||
class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||
class RingBufferReceiveChannel(trio.abc.ReceiveChannel[PayloadT]):
|
||||
'''
|
||||
Ring Buffer receiver side implementation
|
||||
|
||||
|
@ -487,6 +505,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
self,
|
||||
token: RBToken,
|
||||
cleanup: bool = True,
|
||||
decoder: Decoder | None = None
|
||||
):
|
||||
self._token = RBToken.from_msg(token)
|
||||
|
||||
|
@ -513,6 +532,8 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
# set to `False` on `.open()`
|
||||
self._is_closed: bool = True
|
||||
|
||||
self._dec: Decoder | None = decoder
|
||||
|
||||
# ensure no concurrent `.receive_some()` calls
|
||||
self._receive_some_lock = trio.StrictFIFOLock()
|
||||
|
||||
|
@ -648,7 +669,11 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
self._write_ptr += delta
|
||||
# yield lock and re-enter
|
||||
|
||||
except (EFDReadCancelled, trio.Cancelled):
|
||||
except (
|
||||
EFDReadCancelled, # read was cancelled with cscope
|
||||
trio.Cancelled, # read got cancelled from outside
|
||||
trio.BrokenResourceError # OSError EBADF happened while reading
|
||||
):
|
||||
# while waiting for new data `self._write_event` was closed
|
||||
try:
|
||||
# if eof was signaled receive no wait will not raise
|
||||
|
@ -699,7 +724,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
|
||||
return payload
|
||||
|
||||
async def receive(self) -> bytes:
|
||||
async def receive(self, raw: bool = False) -> PayloadT:
|
||||
'''
|
||||
Receive a complete payload or raise EOC
|
||||
|
||||
|
@ -717,7 +742,27 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
if size == 0:
|
||||
raise trio.EndOfChannel
|
||||
|
||||
return await self.receive_exactly(size)
|
||||
raw_msg = await self.receive_exactly(size)
|
||||
if raw:
|
||||
return raw_msg
|
||||
|
||||
return (
|
||||
raw_msg
|
||||
if not self._dec
|
||||
else self._dec.decode(raw_msg)
|
||||
)
|
||||
|
||||
async def iter_raw_pairs(self) -> tuple[bytes, PayloadT]:
|
||||
if not self._dec:
|
||||
raise RuntimeError('iter_raw_pair requires decoder')
|
||||
|
||||
while True:
|
||||
try:
|
||||
raw = await self.receive(raw=True)
|
||||
yield raw, self._dec.decode(raw)
|
||||
|
||||
except trio.EndOfChannel:
|
||||
break
|
||||
|
||||
def open(self):
|
||||
try:
|
||||
|
@ -782,7 +827,8 @@ async def _maybe_obtain_shared_resources(token: RBToken):
|
|||
async def attach_to_ringbuf_receiver(
|
||||
|
||||
token: RBToken,
|
||||
cleanup: bool = True
|
||||
cleanup: bool = True,
|
||||
decoder: Decoder | None = None
|
||||
|
||||
) -> AsyncContextManager[RingBufferReceiveChannel]:
|
||||
'''
|
||||
|
@ -800,7 +846,8 @@ async def attach_to_ringbuf_receiver(
|
|||
trio.open_nursery(strict_exception_groups=False) as n,
|
||||
RingBufferReceiveChannel(
|
||||
token,
|
||||
cleanup=cleanup
|
||||
cleanup=cleanup,
|
||||
decoder=decoder
|
||||
) as receiver
|
||||
):
|
||||
n.start_soon(receiver._eof_monitor_task)
|
||||
|
@ -812,7 +859,8 @@ async def attach_to_ringbuf_sender(
|
|||
|
||||
token: RBToken,
|
||||
batch_size: int = 1,
|
||||
cleanup: bool = True
|
||||
cleanup: bool = True,
|
||||
encoder: Encoder | None = None
|
||||
|
||||
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||
'''
|
||||
|
@ -828,7 +876,8 @@ async def attach_to_ringbuf_sender(
|
|||
async with RingBufferSendChannel(
|
||||
token,
|
||||
batch_size=batch_size,
|
||||
cleanup=cleanup
|
||||
cleanup=cleanup,
|
||||
encoder=encoder
|
||||
) as sender:
|
||||
yield sender
|
||||
|
||||
|
@ -901,6 +950,8 @@ async def attach_to_ringbuf_channel(
|
|||
batch_size: int = 1,
|
||||
cleanup_in: bool = True,
|
||||
cleanup_out: bool = True,
|
||||
encoder: Encoder | None = None,
|
||||
decoder: Decoder | None = None
|
||||
) -> AsyncContextManager[trio.StapledStream]:
|
||||
'''
|
||||
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
|
||||
|
@ -909,12 +960,14 @@ async def attach_to_ringbuf_channel(
|
|||
async with (
|
||||
attach_to_ringbuf_receiver(
|
||||
token_in,
|
||||
cleanup=cleanup_in
|
||||
cleanup=cleanup_in,
|
||||
decoder=decoder
|
||||
) as receiver,
|
||||
attach_to_ringbuf_sender(
|
||||
token_out,
|
||||
batch_size=batch_size,
|
||||
cleanup=cleanup_out
|
||||
cleanup=cleanup_out,
|
||||
encoder=encoder
|
||||
) as sender,
|
||||
):
|
||||
yield RingBufferChannel(sender, receiver)
|
||||
|
|
Loading…
Reference in New Issue