Add optional msgpack encoder & decoder to ringbuf apis
parent
86e09a80f4
commit
8799cf3b78
|
@ -20,6 +20,7 @@ IPC Reliable RingBuffer implementation
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import struct
|
import struct
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TypeVar,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
AsyncContextManager
|
AsyncContextManager
|
||||||
)
|
)
|
||||||
|
@ -34,6 +35,10 @@ from msgspec import (
|
||||||
Struct,
|
Struct,
|
||||||
to_builtins
|
to_builtins
|
||||||
)
|
)
|
||||||
|
from msgspec.msgpack import (
|
||||||
|
Encoder,
|
||||||
|
Decoder,
|
||||||
|
)
|
||||||
|
|
||||||
from tractor.log import get_logger
|
from tractor.log import get_logger
|
||||||
from tractor._exceptions import (
|
from tractor._exceptions import (
|
||||||
|
@ -277,7 +282,10 @@ next full payload.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
PayloadT = TypeVar('PayloadT')
|
||||||
|
|
||||||
|
|
||||||
|
class RingBufferSendChannel(trio.abc.SendChannel[PayloadT]):
|
||||||
'''
|
'''
|
||||||
Ring Buffer sender side implementation
|
Ring Buffer sender side implementation
|
||||||
|
|
||||||
|
@ -298,7 +306,8 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
self,
|
self,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
cleanup: bool = False
|
cleanup: bool = False,
|
||||||
|
encoder: Encoder | None = None
|
||||||
):
|
):
|
||||||
self._token = RBToken.from_msg(token)
|
self._token = RBToken.from_msg(token)
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
@ -319,6 +328,8 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
# close shm & fds on exit?
|
# close shm & fds on exit?
|
||||||
self._cleanup: bool = cleanup
|
self._cleanup: bool = cleanup
|
||||||
|
|
||||||
|
self._enc: Encoder | None = encoder
|
||||||
|
|
||||||
# have we closed this ringbuf?
|
# have we closed this ringbuf?
|
||||||
# set to `False` on `.open()`
|
# set to `False` on `.open()`
|
||||||
self._is_closed: bool = True
|
self._is_closed: bool = True
|
||||||
|
@ -415,15 +426,22 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
if new_batch_size:
|
if new_batch_size:
|
||||||
self.batch_size = new_batch_size
|
self.batch_size = new_batch_size
|
||||||
|
|
||||||
async def send(self, value: bytes) -> None:
|
async def send(self, value: PayloadT) -> None:
|
||||||
if self.closed:
|
if self.closed:
|
||||||
raise trio.ClosedResourceError
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
if self._send_lock.locked():
|
if self._send_lock.locked():
|
||||||
raise trio.BusyResourceError
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
raw_value: bytes = (
|
||||||
|
value
|
||||||
|
if isinstance(value, bytes)
|
||||||
|
else
|
||||||
|
self._enc.encode(value)
|
||||||
|
)
|
||||||
|
|
||||||
async with self._send_lock:
|
async with self._send_lock:
|
||||||
msg: bytes = struct.pack("<I", len(value)) + value
|
msg: bytes = struct.pack("<I", len(raw_value)) + raw_value
|
||||||
if self.batch_size == 1:
|
if self.batch_size == 1:
|
||||||
if len(self._batch) > 0:
|
if len(self._batch) > 0:
|
||||||
await self.flush()
|
await self.flush()
|
||||||
|
@ -475,7 +493,7 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
class RingBufferReceiveChannel(trio.abc.ReceiveChannel[PayloadT]):
|
||||||
'''
|
'''
|
||||||
Ring Buffer receiver side implementation
|
Ring Buffer receiver side implementation
|
||||||
|
|
||||||
|
@ -487,6 +505,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
self,
|
self,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
cleanup: bool = True,
|
cleanup: bool = True,
|
||||||
|
decoder: Decoder | None = None
|
||||||
):
|
):
|
||||||
self._token = RBToken.from_msg(token)
|
self._token = RBToken.from_msg(token)
|
||||||
|
|
||||||
|
@ -513,6 +532,8 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
# set to `False` on `.open()`
|
# set to `False` on `.open()`
|
||||||
self._is_closed: bool = True
|
self._is_closed: bool = True
|
||||||
|
|
||||||
|
self._dec: Decoder | None = decoder
|
||||||
|
|
||||||
# ensure no concurrent `.receive_some()` calls
|
# ensure no concurrent `.receive_some()` calls
|
||||||
self._receive_some_lock = trio.StrictFIFOLock()
|
self._receive_some_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
@ -648,7 +669,11 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
self._write_ptr += delta
|
self._write_ptr += delta
|
||||||
# yield lock and re-enter
|
# yield lock and re-enter
|
||||||
|
|
||||||
except (EFDReadCancelled, trio.Cancelled):
|
except (
|
||||||
|
EFDReadCancelled, # read was cancelled with cscope
|
||||||
|
trio.Cancelled, # read got cancelled from outside
|
||||||
|
trio.BrokenResourceError # OSError EBADF happened while reading
|
||||||
|
):
|
||||||
# while waiting for new data `self._write_event` was closed
|
# while waiting for new data `self._write_event` was closed
|
||||||
try:
|
try:
|
||||||
# if eof was signaled receive no wait will not raise
|
# if eof was signaled receive no wait will not raise
|
||||||
|
@ -699,7 +724,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def receive(self) -> bytes:
|
async def receive(self, raw: bool = False) -> PayloadT:
|
||||||
'''
|
'''
|
||||||
Receive a complete payload or raise EOC
|
Receive a complete payload or raise EOC
|
||||||
|
|
||||||
|
@ -717,7 +742,27 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
if size == 0:
|
if size == 0:
|
||||||
raise trio.EndOfChannel
|
raise trio.EndOfChannel
|
||||||
|
|
||||||
return await self.receive_exactly(size)
|
raw_msg = await self.receive_exactly(size)
|
||||||
|
if raw:
|
||||||
|
return raw_msg
|
||||||
|
|
||||||
|
return (
|
||||||
|
raw_msg
|
||||||
|
if not self._dec
|
||||||
|
else self._dec.decode(raw_msg)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def iter_raw_pairs(self) -> tuple[bytes, PayloadT]:
|
||||||
|
if not self._dec:
|
||||||
|
raise RuntimeError('iter_raw_pair requires decoder')
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
raw = await self.receive(raw=True)
|
||||||
|
yield raw, self._dec.decode(raw)
|
||||||
|
|
||||||
|
except trio.EndOfChannel:
|
||||||
|
break
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
try:
|
try:
|
||||||
|
@ -782,7 +827,8 @@ async def _maybe_obtain_shared_resources(token: RBToken):
|
||||||
async def attach_to_ringbuf_receiver(
|
async def attach_to_ringbuf_receiver(
|
||||||
|
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
cleanup: bool = True
|
cleanup: bool = True,
|
||||||
|
decoder: Decoder | None = None
|
||||||
|
|
||||||
) -> AsyncContextManager[RingBufferReceiveChannel]:
|
) -> AsyncContextManager[RingBufferReceiveChannel]:
|
||||||
'''
|
'''
|
||||||
|
@ -800,7 +846,8 @@ async def attach_to_ringbuf_receiver(
|
||||||
trio.open_nursery(strict_exception_groups=False) as n,
|
trio.open_nursery(strict_exception_groups=False) as n,
|
||||||
RingBufferReceiveChannel(
|
RingBufferReceiveChannel(
|
||||||
token,
|
token,
|
||||||
cleanup=cleanup
|
cleanup=cleanup,
|
||||||
|
decoder=decoder
|
||||||
) as receiver
|
) as receiver
|
||||||
):
|
):
|
||||||
n.start_soon(receiver._eof_monitor_task)
|
n.start_soon(receiver._eof_monitor_task)
|
||||||
|
@ -812,7 +859,8 @@ async def attach_to_ringbuf_sender(
|
||||||
|
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
cleanup: bool = True
|
cleanup: bool = True,
|
||||||
|
encoder: Encoder | None = None
|
||||||
|
|
||||||
) -> AsyncContextManager[RingBufferSendChannel]:
|
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||||
'''
|
'''
|
||||||
|
@ -828,7 +876,8 @@ async def attach_to_ringbuf_sender(
|
||||||
async with RingBufferSendChannel(
|
async with RingBufferSendChannel(
|
||||||
token,
|
token,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
cleanup=cleanup
|
cleanup=cleanup,
|
||||||
|
encoder=encoder
|
||||||
) as sender:
|
) as sender:
|
||||||
yield sender
|
yield sender
|
||||||
|
|
||||||
|
@ -901,6 +950,8 @@ async def attach_to_ringbuf_channel(
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
cleanup_in: bool = True,
|
cleanup_in: bool = True,
|
||||||
cleanup_out: bool = True,
|
cleanup_out: bool = True,
|
||||||
|
encoder: Encoder | None = None,
|
||||||
|
decoder: Decoder | None = None
|
||||||
) -> AsyncContextManager[trio.StapledStream]:
|
) -> AsyncContextManager[trio.StapledStream]:
|
||||||
'''
|
'''
|
||||||
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
|
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
|
||||||
|
@ -909,12 +960,14 @@ async def attach_to_ringbuf_channel(
|
||||||
async with (
|
async with (
|
||||||
attach_to_ringbuf_receiver(
|
attach_to_ringbuf_receiver(
|
||||||
token_in,
|
token_in,
|
||||||
cleanup=cleanup_in
|
cleanup=cleanup_in,
|
||||||
|
decoder=decoder
|
||||||
) as receiver,
|
) as receiver,
|
||||||
attach_to_ringbuf_sender(
|
attach_to_ringbuf_sender(
|
||||||
token_out,
|
token_out,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
cleanup=cleanup_out
|
cleanup=cleanup_out,
|
||||||
|
encoder=encoder
|
||||||
) as sender,
|
) as sender,
|
||||||
):
|
):
|
||||||
yield RingBufferChannel(sender, receiver)
|
yield RingBufferChannel(sender, receiver)
|
||||||
|
|
Loading…
Reference in New Issue