Compare commits

...

11 Commits

Author SHA1 Message Date
Guillermo Rodriguez efd11f7d74
Trying to make full suite pass with uds 2025-03-27 20:37:52 -03:00
Guillermo Rodriguez 76cee99fc2
Finally switch to using address protocol in all runtime 2025-03-27 20:37:52 -03:00
Guillermo Rodriguez 5f50206d84
Add root and random addr getters on MsgTransport type 2025-03-27 20:37:52 -03:00
Guillermo Rodriguez a47a7a39b1
Starting to make tractor.ipc.Channel work with multiple MsgTransports 2025-03-27 20:37:52 -03:00
Guillermo Rodriguez bab265b2d8
Important RingBuffBytesSender fix on non batched mode! & downgrade nix-shell python to lowest supported 2025-03-27 20:36:46 -03:00
Guillermo Rodriguez 010874bed5
Catch trio cancellation on RingBuffReceiver bg eof listener task, add batched mode to RingBuffBytesSender 2025-03-27 20:36:46 -03:00
Guillermo Rodriguez ea010ab46a
Add direct read method on EventFD
Type hint all ctx managers in _ringbuf.py
Remove unnecessary send lock on ring chan sender
Handle EOF on ring chan receiver
Rename ringbuf tests to make it less redundant
2025-03-27 20:36:46 -03:00
Guillermo Rodriguez be7fc89ae9
Add direct ctx managers for RB channels 2025-03-27 20:36:46 -03:00
Guillermo Rodriguez 2a9a78651b
Improve test_ringbuf test, drop MsgTransport ring buf impl for now in favour of a trio.abc.Channel[bytes] impl, add docstrings 2025-03-27 20:36:46 -03:00
Guillermo Rodriguez be818a720a
Switch `tractor.ipc.MsgTransport.stream` type to `trio.abc.Stream`
Add EOF signaling mechanism
Support proper `receive_some` end of stream semantics
Add StapledStream non-ipc test
Create MsgpackRBStream similar to MsgpackTCPStream for buffered whole-msg reads
Add EventFD.read cancellation on EventFD.close mechanism using cancel scope
Add test for eventfd cancellation
Improve and add docstrings
2025-03-27 20:36:46 -03:00
Guillermo Rodriguez ba353bf46f
Better encapsulate RingBuff ctx management methods and support non ipc usage
Add trio.StrictFIFOLock on sender.send_all
Support max_bytes argument on receive_some, keep track of write_ptr on receiver
Add max_bytes receive test test_ringbuf_max_bytes
Add docstrings to all ringbuf tests
Remove EFD_NONBLOCK support, not necessary anymore since we can use abandon_on_cancel=True on trio.to_thread.run_sync
Close eventfd's after usage on open_ringbuf
2025-03-27 20:36:46 -03:00
29 changed files with 2101 additions and 819 deletions

View File

@ -10,9 +10,10 @@ pkgs.mkShell {
inherit nativeBuildInputs; inherit nativeBuildInputs;
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs; LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
TMPDIR = "/tmp";
shellHook = '' shellHook = ''
set -e set -e
uv venv .venv --python=3.12 uv venv .venv --python=3.11
''; '';
} }

View File

@ -9,7 +9,7 @@ async def main(service_name):
async with tractor.open_nursery() as an: async with tractor.open_nursery() as an:
await an.start_actor(service_name) await an.start_actor(service_name)
async with tractor.get_registry('127.0.0.1', 1616) as portal: async with tractor.get_registry() as portal:
print(f"Arbiter is listening on {portal.channel}") print(f"Arbiter is listening on {portal.channel}")
async with tractor.wait_for_actor(service_name) as sockaddr: async with tractor.wait_for_actor(service_name) as sockaddr:

View File

@ -26,7 +26,7 @@ async def test_reg_then_unreg(reg_addr):
portal = await n.start_actor('actor', enable_modules=[__name__]) portal = await n.start_actor('actor', enable_modules=[__name__])
uid = portal.channel.uid uid = portal.channel.uid
async with tractor.get_registry(*reg_addr) as aportal: async with tractor.get_registry(reg_addr) as aportal:
# this local actor should be the arbiter # this local actor should be the arbiter
assert actor is aportal.actor assert actor is aportal.actor
@ -160,7 +160,7 @@ async def spawn_and_check_registry(
async with tractor.open_root_actor( async with tractor.open_root_actor(
registry_addrs=[reg_addr], registry_addrs=[reg_addr],
): ):
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
# runtime needs to be up to call this # runtime needs to be up to call this
actor = tractor.current_actor() actor = tractor.current_actor()
@ -300,7 +300,7 @@ async def close_chans_before_nursery(
async with tractor.open_root_actor( async with tractor.open_root_actor(
registry_addrs=[reg_addr], registry_addrs=[reg_addr],
): ):
async with tractor.get_registry(*reg_addr) as aportal: async with tractor.get_registry(reg_addr) as aportal:
try: try:
get_reg = partial(unpack_reg, aportal) get_reg = partial(unpack_reg, aportal)

View File

@ -66,6 +66,9 @@ def run_example_in_subproc(
# due to backpressure!!! # due to backpressure!!!
proc = testdir.popen( proc = testdir.popen(
cmdargs, cmdargs,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
**kwargs, **kwargs,
) )
assert not proc.returncode assert not proc.returncode
@ -119,10 +122,14 @@ def test_example(
code = ex.read() code = ex.read()
with run_example_in_subproc(code) as proc: with run_example_in_subproc(code) as proc:
proc.wait() err = None
err, _ = proc.stderr.read(), proc.stdout.read() try:
# print(f'STDERR: {err}') if not proc.poll():
# print(f'STDOUT: {out}') _, err = proc.communicate(timeout=15)
except subprocess.TimeoutExpired as e:
proc.kill()
err = e.stderr
# if we get some gnarly output let's aggregate and raise # if we get some gnarly output let's aggregate and raise
if err: if err:

View File

@ -0,0 +1,32 @@
import trio
import pytest
from tractor.ipc import (
open_eventfd,
EFDReadCancelled,
EventFD
)
def test_eventfd_read_cancellation():
    '''
    Ensure `EventFD.read()` raises `EFDReadCancelled` if
    `EventFD.close()` is called while the read is pending.

    '''
    fd = open_eventfd()

    async def _read(event: EventFD):
        # the pending read is expected to be interrupted by the
        # `event.close()` call made from the parent task.
        with pytest.raises(EFDReadCancelled):
            await event.read()

    async def main():
        # NOTE: the timeout has to enclose the nursery. Child tasks
        # are governed by the cancel scopes surrounding the nursery
        # itself, not by scopes opened inside its body, so a
        # `fail_after()` placed inside the nursery could not cancel
        # a reader task that never wakes up.
        with trio.fail_after(3):
            async with trio.open_nursery() as n:
                with EventFD(fd, 'w') as event:
                    n.start_soon(_read, event)
                    # let the reader task get scheduled before closing
                    await trio.sleep(0.2)
                    event.close()

    trio.run(main)

View File

@ -871,7 +871,7 @@ async def serve_subactors(
) )
await ipc.send(( await ipc.send((
peer.chan.uid, peer.chan.uid,
peer.chan.raddr, peer.chan.raddr.unwrap(),
)) ))
print('Spawner exiting spawn serve loop!') print('Spawner exiting spawn serve loop!')

View File

@ -38,7 +38,7 @@ async def test_self_is_registered_localportal(reg_addr):
"Verify waiting on the arbiter to register itself using a local portal." "Verify waiting on the arbiter to register itself using a local portal."
actor = tractor.current_actor() actor = tractor.current_actor()
assert actor.is_arbiter assert actor.is_arbiter
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
assert isinstance(portal, tractor._portal.LocalPortal) assert isinstance(portal, tractor._portal.LocalPortal)
with trio.fail_after(0.2): with trio.fail_after(0.2):

View File

@ -32,7 +32,7 @@ def test_abort_on_sigint(daemon):
@tractor_test @tractor_test
async def test_cancel_remote_arbiter(daemon, reg_addr): async def test_cancel_remote_arbiter(daemon, reg_addr):
assert not tractor.current_actor().is_arbiter assert not tractor.current_actor().is_arbiter
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
await portal.cancel_actor() await portal.cancel_actor()
time.sleep(0.1) time.sleep(0.1)
@ -41,7 +41,7 @@ async def test_cancel_remote_arbiter(daemon, reg_addr):
# no arbiter socket should exist # no arbiter socket should exist
with pytest.raises(OSError): with pytest.raises(OSError):
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
pass pass

View File

@ -1,15 +1,21 @@
import time import time
import hashlib
import trio import trio
import pytest import pytest
import tractor import tractor
from tractor.ipc import ( from tractor.ipc import (
open_ringbuf, open_ringbuf,
attach_to_ringbuf_receiver,
attach_to_ringbuf_sender,
attach_to_ringbuf_stream,
attach_to_ringbuf_channel,
RBToken, RBToken,
RingBuffSender,
RingBuffReceiver
) )
from tractor._testing.samples import generate_sample_messages from tractor._testing.samples import (
generate_single_byte_msgs,
generate_sample_messages
)
@tractor.context @tractor.context
@ -17,19 +23,27 @@ async def child_read_shm(
ctx: tractor.Context, ctx: tractor.Context,
msg_amount: int, msg_amount: int,
token: RBToken, token: RBToken,
total_bytes: int, ) -> str:
) -> None: '''
recvd_bytes = 0 Sub-actor used in `test_ringbuf`.
await ctx.started()
start_ts = time.time()
async with RingBuffReceiver(token) as receiver:
while recvd_bytes < total_bytes:
msg = await receiver.receive_some()
recvd_bytes += len(msg)
# make sure we dont hold any memoryviews Attach to a ringbuf and receive all messages until end of stream.
# before the ctx manager aclose() Keep track of how many bytes received and also calculate
msg = None sha256 of the whole byte stream.
Calculate and print performance stats, finally return calculated
hash.
'''
await ctx.started()
print('reader started')
recvd_bytes = 0
recvd_hash = hashlib.sha256()
start_ts = time.time()
async with attach_to_ringbuf_receiver(token) as receiver:
async for msg in receiver:
recvd_hash.update(msg)
recvd_bytes += len(msg)
end_ts = time.time() end_ts = time.time()
elapsed = end_ts - start_ts elapsed = end_ts - start_ts
@ -38,6 +52,9 @@ async def child_read_shm(
print(f'\n\telapsed ms: {elapsed_ms}') print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}') print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}') print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
print(f'\treceived bytes: {recvd_bytes:,}')
return recvd_hash.hexdigest()
@tractor.context @tractor.context
@ -48,16 +65,32 @@ async def child_write_shm(
rand_max: int, rand_max: int,
token: RBToken, token: RBToken,
) -> None: ) -> None:
msgs, total_bytes = generate_sample_messages( '''
Sub-actor used in `test_ringbuf`
Generate `msg_amount` payloads with
`random.randint(rand_min, rand_max)` random bytes at the end,
Calculate sha256 hash and send it to parent on `ctx.started`.
Attach to ringbuf and send all generated messages.
'''
msgs, _total_bytes = generate_sample_messages(
msg_amount, msg_amount,
rand_min=rand_min, rand_min=rand_min,
rand_max=rand_max, rand_max=rand_max,
) )
await ctx.started(total_bytes) print('writer hashing payload...')
async with RingBuffSender(token) as sender: sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest()
print('writer done hashing.')
await ctx.started(sent_hash)
print('writer started')
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
for msg in msgs: for msg in msgs:
await sender.send_all(msg) await sender.send_all(msg)
print('writer exit')
@pytest.mark.parametrize( @pytest.mark.parametrize(
'msg_amount,rand_min,rand_max,buf_size', 'msg_amount,rand_min,rand_max,buf_size',
@ -83,19 +116,23 @@ def test_ringbuf(
rand_max: int, rand_max: int,
buf_size: int buf_size: int
): ):
'''
- Open a new ring buf on root actor
- Open `child_write_shm` ctx in sub-actor which will generate a
random payload and send its hash on `ctx.started`, finally sending
the payload through the stream.
- Open `child_read_shm` ctx in sub-actor which will receive the
payload, calculate perf stats and return the hash.
- Compare both hashes
'''
async def main(): async def main():
with open_ringbuf( with open_ringbuf(
'test_ringbuf', 'test_ringbuf',
buf_size=buf_size buf_size=buf_size
) as token: ) as token:
proc_kwargs = { proc_kwargs = {'pass_fds': token.fds}
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
}
common_kwargs = {
'msg_amount': msg_amount,
'token': token,
}
async with tractor.open_nursery() as an: async with tractor.open_nursery() as an:
send_p = await an.start_actor( send_p = await an.start_actor(
'ring_sender', 'ring_sender',
@ -110,17 +147,20 @@ def test_ringbuf(
async with ( async with (
send_p.open_context( send_p.open_context(
child_write_shm, child_write_shm,
token=token,
msg_amount=msg_amount,
rand_min=rand_min, rand_min=rand_min,
rand_max=rand_max, rand_max=rand_max,
**common_kwargs ) as (_sctx, sent_hash),
) as (sctx, total_bytes),
recv_p.open_context( recv_p.open_context(
child_read_shm, child_read_shm,
**common_kwargs, token=token,
total_bytes=total_bytes, msg_amount=msg_amount
) as (sctx, _sent), ) as (rctx, _sent),
): ):
await recv_p.result() recvd_hash = await rctx.result()
assert sent_hash == recvd_hash
await send_p.cancel_actor() await send_p.cancel_actor()
await recv_p.cancel_actor() await recv_p.cancel_actor()
@ -134,23 +174,28 @@ async def child_blocked_receiver(
ctx: tractor.Context, ctx: tractor.Context,
token: RBToken token: RBToken
): ):
async with RingBuffReceiver(token) as receiver: async with attach_to_ringbuf_receiver(token) as receiver:
await ctx.started() await ctx.started()
await receiver.receive_some() await receiver.receive_some()
def test_ring_reader_cancel(): def test_reader_cancel():
'''
Test that a receiver blocked on eventfd(2) read responds to
cancellation.
'''
async def main(): async def main():
with open_ringbuf('test_ring_cancel_reader') as token: with open_ringbuf('test_ring_cancel_reader') as token:
async with ( async with (
tractor.open_nursery() as an, tractor.open_nursery() as an,
RingBuffSender(token) as _sender, attach_to_ringbuf_sender(token) as _sender,
): ):
recv_p = await an.start_actor( recv_p = await an.start_actor(
'ring_blocked_receiver', 'ring_blocked_receiver',
enable_modules=[__name__], enable_modules=[__name__],
proc_kwargs={ proc_kwargs={
'pass_fds': (token.write_eventfd, token.wrap_eventfd) 'pass_fds': token.fds
} }
) )
async with ( async with (
@ -172,12 +217,17 @@ async def child_blocked_sender(
ctx: tractor.Context, ctx: tractor.Context,
token: RBToken token: RBToken
): ):
async with RingBuffSender(token) as sender: async with attach_to_ringbuf_sender(token) as sender:
await ctx.started() await ctx.started()
await sender.send_all(b'this will wrap') await sender.send_all(b'this will wrap')
def test_ring_sender_cancel(): def test_sender_cancel():
'''
Test that a sender blocked on eventfd(2) read responds to
cancellation.
'''
async def main(): async def main():
with open_ringbuf( with open_ringbuf(
'test_ring_cancel_sender', 'test_ring_cancel_sender',
@ -188,7 +238,7 @@ def test_ring_sender_cancel():
'ring_blocked_sender', 'ring_blocked_sender',
enable_modules=[__name__], enable_modules=[__name__],
proc_kwargs={ proc_kwargs={
'pass_fds': (token.write_eventfd, token.wrap_eventfd) 'pass_fds': token.fds
} }
) )
async with ( async with (
@ -203,3 +253,171 @@ def test_ring_sender_cancel():
with pytest.raises(tractor._exceptions.ContextCancelled): with pytest.raises(tractor._exceptions.ContextCancelled):
trio.run(main) trio.run(main)
def test_receiver_max_bytes():
    '''
    Check the optional `max_bytes` argument of
    `RingBuffReceiver.receive_some()`.

    A 100 byte payload is sent through a 10 byte ring buffer while
    the receiver reads with `max_bytes=1` until it has collected as
    many chunks as the payload has bytes; the joined chunks must
    equal the original payload.

    '''
    payload = generate_single_byte_msgs(100)
    chunks = []

    async def main():
        with open_ringbuf(
            'test_ringbuf_max_bytes',
            buf_size=10,
        ) as token:
            async with (
                trio.open_nursery() as n,
                attach_to_ringbuf_sender(token, cleanup=False) as sender,
                attach_to_ringbuf_receiver(token, cleanup=False) as receiver,
            ):
                async def _send_and_close():
                    await sender.send_all(payload)
                    await sender.aclose()

                # send in the background so this task is free to
                # read; the payload is larger than the buffer.
                n.start_soon(_send_and_close)

                while len(chunks) < len(payload):
                    chunk = await receiver.receive_some(max_bytes=1)
                    assert len(chunk) == 1
                    chunks.append(chunk)

    trio.run(main)

    assert payload == b''.join(chunks)
def test_stapled_ringbuf():
    '''
    Open a pair of ringbufs and hand the tokens to two tasks, swapped
    so that one task's input is the other task's output. Each task
    attaches a single bidirectional stream via
    `attach_to_ringbuf_stream()`.

    The tasks then take turns: the first sends while the second
    receives, then the roles are reversed.

    '''
    payload = generate_single_byte_msgs(100)

    pair_0_msgs = []
    pair_1_msgs = []

    pair_0_done = trio.Event()
    pair_1_done = trio.Event()

    async def pair_0(token_in: RBToken, token_out: RBToken):
        async with attach_to_ringbuf_stream(
            token_in,
            token_out,
            cleanup_in=False,
            cleanup_out=False,
        ) as stream:
            # turn 1: send the whole payload
            await stream.send_all(payload)

            # turn 2: read it back one byte at a time
            while len(pair_0_msgs) != len(payload):
                chunk = await stream.receive_some(max_bytes=1)
                pair_0_msgs.append(chunk)

            # keep the stream open until the peer has finished too
            pair_0_done.set()
            await pair_1_done.wait()

    async def pair_1(token_in: RBToken, token_out: RBToken):
        async with attach_to_ringbuf_stream(
            token_in,
            token_out,
            cleanup_in=False,
            cleanup_out=False,
        ) as stream:
            # turn 1: read the peer's payload one byte at a time
            while len(pair_1_msgs) != len(payload):
                chunk = await stream.receive_some(max_bytes=1)
                pair_1_msgs.append(chunk)

            # turn 2: send the payload back
            await stream.send_all(payload)

            # keep the stream open until the peer has finished too
            pair_1_done.set()
            await pair_0_done.wait()

    async def main():
        with tractor.ipc.open_ringbuf_pair(
            'test_stapled_ringbuf'
        ) as (token_0, token_1):
            async with trio.open_nursery() as n:
                n.start_soon(pair_0, token_0, token_1)
                n.start_soon(pair_1, token_1, token_0)

    trio.run(main)

    assert payload == b''.join(pair_0_msgs)
    assert payload == b''.join(pair_1_msgs)
@tractor.context
async def child_channel_sender(
    ctx: tractor.Context,
    msg_amount_min: int,
    msg_amount_max: int,
    token_in: RBToken,
    token_out: RBToken
):
    '''
    Sub-actor used in `test_channel`.

    Generate a random amount of sample messages (between
    `msg_amount_min` and `msg_amount_max`), attach to the ringbuf
    channel, hand the generated messages to the parent through
    `ctx.started()` and then send each of them over the channel.

    '''
    import random

    amount = random.randint(msg_amount_min, msg_amount_max)
    payloads, _total_bytes = generate_sample_messages(
        amount,
        rand_min=256,
        rand_max=1024,
    )

    async with attach_to_ringbuf_channel(token_in, token_out) as chan:
        await ctx.started(payloads)

        for payload in payloads:
            await chan.send(payload)
def test_channel():
    '''
    Open a ringbuf pair, attach a channel to it in the root actor and
    spawn a sub-actor running `child_channel_sender` on the swapped
    tokens. Every message read from the channel is collected and the
    result is compared with the message list the child reported via
    `ctx.started()`.

    '''
    msg_amount_min = 100
    msg_amount_max = 1000

    async def main():
        with tractor.ipc.open_ringbuf_pair(
            'test_ringbuf_transport'
        ) as (token_0, token_1):
            async with (
                attach_to_ringbuf_channel(token_0, token_1) as chan,
                tractor.open_nursery() as an,
            ):
                # the child needs the fds of both ring buffers
                recv_p = await an.start_actor(
                    'test_ringbuf_transport_sender',
                    enable_modules=[__name__],
                    proc_kwargs={
                        'pass_fds': token_0.fds + token_1.fds
                    },
                )
                async with recv_p.open_context(
                    child_channel_sender,
                    msg_amount_min=msg_amount_min,
                    msg_amount_max=msg_amount_max,
                    token_in=token_1,
                    token_out=token_0,
                ) as (ctx, msgs):
                    recv_msgs = [msg async for msg in chan]

                await recv_p.cancel_actor()
                assert recv_msgs == msgs

    trio.run(main)

View File

@ -77,7 +77,7 @@ async def movie_theatre_question():
async def test_movie_theatre_convo(start_method): async def test_movie_theatre_convo(start_method):
"""The main ``tractor`` routine. """The main ``tractor`` routine.
""" """
async with tractor.open_nursery() as n: async with tractor.open_nursery(debug_mode=True) as n:
portal = await n.start_actor( portal = await n.start_actor(
'frank', 'frank',

310
tractor/_addr.py 100644
View File

@ -0,0 +1,310 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import os
import tempfile
from uuid import uuid4
from typing import (
Protocol,
ClassVar,
TypeVar,
Union,
Type
)
import trio
from trio import socket
# type parameters of the generic `Address` protocol
NamespaceType = TypeVar('NamespaceType')
AddressType = TypeVar('AddressType')
StreamType = TypeVar('StreamType')
ListenerType = TypeVar('ListenerType')


class Address(Protocol[
    NamespaceType,
    AddressType,
    StreamType,
    ListenerType
]):
    '''
    Structural interface for a transport address.

    An implementation wraps a plain ("unwrapped") address value of
    type `AddressType` and knows how to open streams and listeners
    for it.

    '''
    # short transport name used as a lookup key
    name_key: ClassVar[str]

    # type of the plain address value wrapped by the implementation
    address_type: ClassVar[Type[AddressType]]

    @property
    def is_valid(self) -> bool:
        '''
        Whether this address is considered valid.

        '''
        ...

    @property
    def namespace(self) -> NamespaceType|None:
        '''
        The namespace of this address, if the transport has one.

        '''
        ...

    @classmethod
    def from_addr(cls, addr: AddressType) -> Address:
        '''
        Build an instance from a plain address value.

        '''
        ...

    def unwrap(self) -> AddressType:
        '''
        Return the plain address value.

        '''
        ...

    @classmethod
    def get_random(cls, namespace: NamespaceType | None = None) -> Address:
        '''
        Build a new, randomly chosen address.

        '''
        ...

    @classmethod
    def get_root(cls) -> Address:
        '''
        Build the default root address of the transport.

        '''
        ...

    def __repr__(self) -> str:
        ...

    def __eq__(self, other) -> bool:
        ...

    async def open_stream(self, **kwargs) -> StreamType:
        '''
        Open a stream for this address.

        '''
        ...

    async def open_listener(self, **kwargs) -> ListenerType:
        '''
        Open a listener for this address.

        '''
        ...

    async def close_listener(self):
        '''
        Release whatever `open_listener()` acquired.

        '''
        ...
class TCPAddress(Address[
    str,
    tuple[str, int],
    trio.SocketStream,
    trio.SocketListener
]):
    '''
    `Address` implementation for TCP, wrapping a `(host, port)` pair.

    '''
    name_key: str = 'tcp'
    address_type: type = tuple[str, int]

    def __init__(
        self,
        host: str,
        port: int
    ):
        '''
        Raises `TypeError` unless `host` is a `str` and `port` an
        `int`.

        '''
        if (
            not isinstance(host, str)
            or
            not isinstance(port, int)
        ):
            raise TypeError(f'Expected host {host} to be str and port {port} to be int')

        self._host = host
        self._port = port

    @property
    def is_valid(self) -> bool:
        # port 0 is what `get_random()` produces, ie. "not bound yet"
        return self._port != 0

    @property
    def namespace(self) -> str:
        return self._host

    @classmethod
    def from_addr(cls, addr: tuple[str, int]) -> TCPAddress:
        '''
        Build an instance from a `(host, port)` sequence.

        '''
        # use `cls` so subclasses get instances of their own type
        return cls(addr[0], addr[1])

    def unwrap(self) -> tuple[str, int]:
        return self._host, self._port

    @classmethod
    def get_random(cls, namespace: str = '127.0.0.1') -> TCPAddress:
        '''
        Return an address on host `namespace` with port 0.

        '''
        return cls(namespace, 0)

    @classmethod
    def get_root(cls) -> Address:
        return cls('127.0.0.1', 1616)

    def __repr__(self) -> str:
        return f'{type(self)} @ {self.unwrap()}'

    def __eq__(self, other) -> bool:
        # Returning `NotImplemented` (instead of raising) lets python
        # try the reflected comparison and finally fall back to
        # "not equal", so `addr == None` or `addr in mixed_list`
        # no longer blow up when another address type is involved.
        if not isinstance(other, TCPAddress):
            return NotImplemented

        return (
            self._host == other._host
            and
            self._port == other._port
        )

    async def open_stream(self, **kwargs) -> trio.SocketStream:
        '''
        Connect to the wrapped `(host, port)`; `kwargs` are forwarded
        to `trio.open_tcp_stream()`.

        '''
        stream = await trio.open_tcp_stream(
            self._host,
            self._port,
            **kwargs
        )
        # NOTE(review): this replaces the wrapped address with the
        # socket's own name as reported by `getsockname()` - confirm
        # that overwriting the connect target is intended.
        self._host, self._port = stream.socket.getsockname()[:2]
        return stream

    async def open_listener(self, **kwargs) -> trio.SocketListener:
        '''
        Open a listener on the wrapped `(host, port)`; `kwargs` are
        forwarded to `trio.open_tcp_listeners()`.

        '''
        listeners = await trio.open_tcp_listeners(
            host=self._host,
            port=self._port,
            **kwargs
        )
        assert len(listeners) == 1
        listener = listeners[0]
        # record the address actually bound (eg. when port was 0)
        self._host, self._port = listener.socket.getsockname()[:2]
        return listener

    async def close_listener(self):
        # nothing to release for TCP
        ...
class UDSAddress(Address[
    None,
    str,
    trio.SocketStream,
    trio.SocketListener
]):
    '''
    `Address` implementation for unix domain sockets, wrapping the
    socket's file path.

    '''
    name_key: str = 'uds'
    address_type: type = str

    def __init__(
        self,
        filepath: str
    ):
        self._filepath = filepath
        # listening socket; only set while a listener is open so that
        # `close_listener()` is safe to call at any time.
        self._sock = None

    @property
    def is_valid(self) -> bool:
        return True

    @property
    def namespace(self) -> None:
        return

    @classmethod
    def from_addr(cls, filepath: str) -> UDSAddress:
        # use `cls` so subclasses get instances of their own type
        return cls(filepath)

    def unwrap(self) -> str:
        return self._filepath

    @classmethod
    def get_random(cls, namespace: None = None) -> UDSAddress:
        '''
        Return an address with a uuid4 based `.sock` file name inside
        the system temp dir; `namespace` is ignored.

        '''
        return cls(f'{tempfile.gettempdir()}/{uuid4()}.sock')

    @classmethod
    def get_root(cls) -> Address:
        return cls('tractor.sock')

    def __repr__(self) -> str:
        return f'{type(self)} @ {self._filepath}'

    def __eq__(self, other) -> bool:
        # Returning `NotImplemented` (instead of raising) lets python
        # try the reflected comparison and finally fall back to
        # "not equal", so `addr == None` or `addr in mixed_list`
        # no longer blow up when another address type is involved.
        if not isinstance(other, UDSAddress):
            return NotImplemented

        return self._filepath == other._filepath

    async def open_stream(self, **kwargs) -> trio.SocketStream:
        '''
        Connect to the wrapped socket path; `kwargs` are forwarded to
        `trio.open_unix_socket()`.

        '''
        stream = await trio.open_unix_socket(
            self._filepath,
            **kwargs
        )
        return stream

    async def open_listener(self, **kwargs) -> trio.SocketListener:
        '''
        Bind a unix stream socket to the wrapped path and return a
        listener for it. `kwargs` is accepted for interface
        compatibility but currently unused.

        '''
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            await sock.bind(self._filepath)
            sock.listen(1)
        except BaseException:
            # don't leak the fd when binding fails or is cancelled
            sock.close()
            raise

        # only remember the socket once it is actually bound, so
        # `close_listener()` never unlinks a path we did not create.
        self._sock = sock
        return trio.SocketListener(sock)

    async def close_listener(self):
        '''
        Close the listening socket and remove the socket file; a
        no-op when no listener is currently open.

        '''
        if self._sock is None:
            return

        self._sock.close()
        self._sock = None
        os.unlink(self._filepath)
# `name_key` of the transport used when no address is specified
preferred_transport = 'uds'

# every supported `Address` implementation
_address_types = (TCPAddress, UDSAddress)

# transport name -> address class
_default_addrs: dict[str, Type[Address]] = {
    addr_cls.name_key: addr_cls
    for addr_cls in _address_types
}

# union of the plain (unwrapped) address types of all transports
AddressTypes = Union[
    tuple(
        addr_cls.address_type
        for addr_cls in _address_types
    )
]

# transport name -> unwrapped root address of that transport
_default_lo_addrs: dict[str, AddressTypes] = {
    addr_cls.name_key: addr_cls.get_root().unwrap()
    for addr_cls in _address_types
}
def get_address_cls(name: str) -> Type[Address]:
    '''
    Look up the address class registered under transport `name`.

    Raises `KeyError` for an unknown transport name.

    '''
    addr_cls = _default_addrs[name]
    return addr_cls
def is_wrapped_addr(addr: object) -> bool:
    '''
    Return `True` if `addr` is an instance of one of the known
    address classes (subclasses included), as opposed to a plain
    address value.

    '''
    # `isinstance()` (rather than an exact `type()` match) also
    # accepts subclasses of the registered address types.
    return isinstance(addr, _address_types)
def wrap_address(addr: AddressTypes) -> Address:
    '''
    Convert a plain address value into its `Address` wrapper.

    - an already wrapped address is returned unchanged,
    - a `str` becomes a `UDSAddress`,
    - a `tuple` or `list` becomes a `TCPAddress`,
    - `None` yields the root address of `preferred_transport`.

    Raises `TypeError` for any other input.

    '''
    if is_wrapped_addr(addr):
        return addr

    if addr is None:
        # no address given: fall back to the preferred transport's
        # root address.
        root_cls = get_address_cls(preferred_transport)
        return root_cls.from_addr(root_cls.get_root().unwrap())

    if isinstance(addr, str):
        return UDSAddress.from_addr(addr)

    if isinstance(addr, (tuple, list)):
        return TCPAddress.from_addr(addr)

    raise TypeError(
        f'Can not wrap addr {addr} of type {type(addr)}'
    )
def default_lo_addrs(transports: list[str]) -> list[AddressTypes]:
    '''
    Return the unwrapped default address of each transport named in
    `transports`, in the same order.

    Raises `KeyError` for an unknown transport name.

    '''
    return list(map(_default_lo_addrs.__getitem__, transports))

View File

@ -31,8 +31,12 @@ def parse_uid(arg):
return str(name), str(uuid) # ensures str encoding return str(name), str(uuid) # ensures str encoding
def parse_ipaddr(arg): def parse_ipaddr(arg):
host, port = literal_eval(arg) try:
return (str(host), int(port)) return literal_eval(arg)
except (ValueError, SyntaxError):
# UDS: try to interpret as a straight up str
return arg
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -859,19 +859,10 @@ class Context:
@property @property
def dst_maddr(self) -> str: def dst_maddr(self) -> str:
chan: Channel = self.chan chan: Channel = self.chan
dst_addr, dst_port = chan.raddr
trans: MsgTransport = chan.transport trans: MsgTransport = chan.transport
# cid: str = self.cid # cid: str = self.cid
# cid_head, cid_tail = cid[:6], cid[-6:] # cid_head, cid_tail = cid[:6], cid[-6:]
return ( return trans.maddr
f'/ipv4/{dst_addr}'
f'/{trans.name_key}/{dst_port}'
# f'/{self.chan.uid[0]}'
# f'/{self.cid}'
# f'/cid={cid_head}..{cid_tail}'
# TODO: ? not use this ^ right ?
)
dmaddr = dst_maddr dmaddr = dst_maddr

View File

@ -30,6 +30,12 @@ from contextlib import asynccontextmanager as acm
from tractor.log import get_logger from tractor.log import get_logger
from .trionics import gather_contexts from .trionics import gather_contexts
from .ipc import _connect_chan, Channel from .ipc import _connect_chan, Channel
from ._addr import (
AddressTypes,
Address,
preferred_transport,
wrap_address
)
from ._portal import ( from ._portal import (
Portal, Portal,
open_portal, open_portal,
@ -48,11 +54,7 @@ log = get_logger(__name__)
@acm @acm
async def get_registry( async def get_registry(addr: AddressTypes | None = None) -> AsyncGenerator[
host: str,
port: int,
) -> AsyncGenerator[
Portal | LocalPortal | None, Portal | LocalPortal | None,
None, None,
]: ]:
@ -69,13 +71,13 @@ async def get_registry(
# (likely a re-entrant call from the arbiter actor) # (likely a re-entrant call from the arbiter actor)
yield LocalPortal( yield LocalPortal(
actor, actor,
Channel((host, port)) await Channel.from_addr(addr)
) )
else: else:
# TODO: try to look pre-existing connection from # TODO: try to look pre-existing connection from
# `Actor._peers` and use it instead? # `Actor._peers` and use it instead?
async with ( async with (
_connect_chan(host, port) as chan, _connect_chan(addr) as chan,
open_portal(chan) as regstr_ptl, open_portal(chan) as regstr_ptl,
): ):
yield regstr_ptl yield regstr_ptl
@ -89,11 +91,10 @@ async def get_root(
# TODO: rename mailbox to `_root_maddr` when we finally # TODO: rename mailbox to `_root_maddr` when we finally
# add and impl libp2p multi-addrs? # add and impl libp2p multi-addrs?
host, port = _runtime_vars['_root_mailbox'] addr = _runtime_vars['_root_mailbox']
assert host is not None
async with ( async with (
_connect_chan(host, port) as chan, _connect_chan(addr) as chan,
open_portal(chan, **kwargs) as portal, open_portal(chan, **kwargs) as portal,
): ):
yield portal yield portal
@ -134,10 +135,10 @@ def get_peer_by_name(
@acm @acm
async def query_actor( async def query_actor(
name: str, name: str,
regaddr: tuple[str, int]|None = None, regaddr: AddressTypes|None = None,
) -> AsyncGenerator[ ) -> AsyncGenerator[
tuple[str, int]|None, AddressTypes|None,
None, None,
]: ]:
''' '''
@ -163,31 +164,31 @@ async def query_actor(
return return
reg_portal: Portal reg_portal: Portal
regaddr: tuple[str, int] = regaddr or actor.reg_addrs[0] regaddr: Address = wrap_address(regaddr) or actor.reg_addrs[0]
async with get_registry(*regaddr) as reg_portal: async with get_registry(regaddr) as reg_portal:
# TODO: return portals to all available actors - for now # TODO: return portals to all available actors - for now
# just the last one that registered # just the last one that registered
sockaddr: tuple[str, int] = await reg_portal.run_from_ns( addr: AddressTypes = await reg_portal.run_from_ns(
'self', 'self',
'find_actor', 'find_actor',
name=name, name=name,
) )
yield sockaddr yield addr
@acm @acm
async def maybe_open_portal( async def maybe_open_portal(
addr: tuple[str, int], addr: AddressTypes,
name: str, name: str,
): ):
async with query_actor( async with query_actor(
name=name, name=name,
regaddr=addr, regaddr=addr,
) as sockaddr: ) as addr:
pass pass
if sockaddr: if addr:
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(addr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal
else: else:
@ -197,7 +198,8 @@ async def maybe_open_portal(
@acm @acm
async def find_actor( async def find_actor(
name: str, name: str,
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
enable_transports: list[str] = [preferred_transport],
only_first: bool = True, only_first: bool = True,
raise_on_none: bool = False, raise_on_none: bool = False,
@ -224,15 +226,15 @@ async def find_actor(
# XXX NOTE: make sure to dynamically read the value on # XXX NOTE: make sure to dynamically read the value on
# every call since something may change it globally (eg. # every call since something may change it globally (eg.
# like in our discovery test suite)! # like in our discovery test suite)!
from . import _root from ._addr import default_lo_addrs
registry_addrs = ( registry_addrs = (
_runtime_vars['_registry_addrs'] _runtime_vars['_registry_addrs']
or or
_root._default_lo_addrs default_lo_addrs(enable_transports)
) )
maybe_portals: list[ maybe_portals: list[
AsyncContextManager[tuple[str, int]] AsyncContextManager[AddressTypes]
] = list( ] = list(
maybe_open_portal( maybe_open_portal(
addr=addr, addr=addr,
@ -274,7 +276,7 @@ async def find_actor(
@acm @acm
async def wait_for_actor( async def wait_for_actor(
name: str, name: str,
registry_addr: tuple[str, int] | None = None, registry_addr: AddressTypes | None = None,
) -> AsyncGenerator[Portal, None]: ) -> AsyncGenerator[Portal, None]:
''' '''
@ -291,7 +293,7 @@ async def wait_for_actor(
yield peer_portal yield peer_portal
return return
regaddr: tuple[str, int] = ( regaddr: AddressTypes = (
registry_addr registry_addr
or or
actor.reg_addrs[0] actor.reg_addrs[0]
@ -299,8 +301,8 @@ async def wait_for_actor(
# TODO: use `.trionics.gather_contexts()` like # TODO: use `.trionics.gather_contexts()` like
# above in `find_actor()` as well? # above in `find_actor()` as well?
reg_portal: Portal reg_portal: Portal
async with get_registry(*regaddr) as reg_portal: async with get_registry(regaddr) as reg_portal:
sockaddrs = await reg_portal.run_from_ns( addrs = await reg_portal.run_from_ns(
'self', 'self',
'wait_for_actor', 'wait_for_actor',
name=name, name=name,
@ -308,8 +310,8 @@ async def wait_for_actor(
# get latest registered addr by default? # get latest registered addr by default?
# TODO: offer multi-portal yields in multi-homed case? # TODO: offer multi-portal yields in multi-homed case?
sockaddr: tuple[str, int] = sockaddrs[-1] addr: AddressTypes = addrs[-1]
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(addr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal

View File

@ -37,6 +37,7 @@ from .log import (
from . import _state from . import _state
from .devx import _debug from .devx import _debug
from .to_asyncio import run_as_asyncio_guest from .to_asyncio import run_as_asyncio_guest
from ._addr import AddressTypes
from ._runtime import ( from ._runtime import (
async_main, async_main,
Actor, Actor,
@ -52,10 +53,10 @@ log = get_logger(__name__)
def _mp_main( def _mp_main(
actor: Actor, actor: Actor,
accept_addrs: list[tuple[str, int]], accept_addrs: list[AddressTypes],
forkserver_info: tuple[Any, Any, Any, Any, Any], forkserver_info: tuple[Any, Any, Any, Any, Any],
start_method: SpawnMethodKey, start_method: SpawnMethodKey,
parent_addr: tuple[str, int] | None = None, parent_addr: AddressTypes | None = None,
infect_asyncio: bool = False, infect_asyncio: bool = False,
) -> None: ) -> None:
@ -206,7 +207,7 @@ def nest_from_op(
def _trio_main( def _trio_main(
actor: Actor, actor: Actor,
*, *,
parent_addr: tuple[str, int] | None = None, parent_addr: AddressTypes | None = None,
infect_asyncio: bool = False, infect_asyncio: bool = False,
) -> None: ) -> None:

View File

@ -43,21 +43,18 @@ from .devx import _debug
from . import _spawn from . import _spawn
from . import _state from . import _state
from . import log from . import log
from .ipc import _connect_chan from .ipc import (
_connect_chan,
)
from ._addr import (
AddressTypes,
wrap_address,
preferred_transport,
default_lo_addrs
)
from ._exceptions import is_multi_cancelled from ._exceptions import is_multi_cancelled
# set at startup and after forks
_default_host: str = '127.0.0.1'
_default_port: int = 1616
# default registry always on localhost
_default_lo_addrs: list[tuple[str, int]] = [(
_default_host,
_default_port,
)]
logger = log.get_logger('tractor') logger = log.get_logger('tractor')
@ -66,10 +63,12 @@ async def open_root_actor(
*, *,
# defaults are above # defaults are above
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
# defaults are above # defaults are above
arbiter_addr: tuple[str, int]|None = None, arbiter_addr: tuple[AddressTypes]|None = None,
enable_transports: list[str] = [preferred_transport],
name: str|None = 'root', name: str|None = 'root',
@ -195,11 +194,9 @@ async def open_root_actor(
) )
registry_addrs = [arbiter_addr] registry_addrs = [arbiter_addr]
registry_addrs: list[tuple[str, int]] = ( if not registry_addrs:
registry_addrs registry_addrs: list[AddressTypes] = default_lo_addrs(enable_transports)
or
_default_lo_addrs
)
assert registry_addrs assert registry_addrs
loglevel = ( loglevel = (
@ -248,10 +245,10 @@ async def open_root_actor(
enable_stack_on_sig() enable_stack_on_sig()
# closed into below ping task-func # closed into below ping task-func
ponged_addrs: list[tuple[str, int]] = [] ponged_addrs: list[AddressTypes] = []
async def ping_tpt_socket( async def ping_tpt_socket(
addr: tuple[str, int], addr: AddressTypes,
timeout: float = 1, timeout: float = 1,
) -> None: ) -> None:
''' '''
@ -271,7 +268,7 @@ async def open_root_actor(
# be better to eventually have a "discovery" protocol # be better to eventually have a "discovery" protocol
# with basic handshake instead? # with basic handshake instead?
with trio.move_on_after(timeout): with trio.move_on_after(timeout):
async with _connect_chan(*addr): async with _connect_chan(addr):
ponged_addrs.append(addr) ponged_addrs.append(addr)
except OSError: except OSError:
@ -284,10 +281,10 @@ async def open_root_actor(
for addr in registry_addrs: for addr in registry_addrs:
tn.start_soon( tn.start_soon(
ping_tpt_socket, ping_tpt_socket,
tuple(addr), # TODO: just drop this requirement? addr,
) )
trans_bind_addrs: list[tuple[str, int]] = [] trans_bind_addrs: list[AddressTypes] = []
# Create a new local root-actor instance which IS NOT THE # Create a new local root-actor instance which IS NOT THE
# REGISTRAR # REGISTRAR
@ -311,9 +308,12 @@ async def open_root_actor(
) )
# DO NOT use the registry_addrs as the transport server # DO NOT use the registry_addrs as the transport server
# addrs for this new non-registar, root-actor. # addrs for this new non-registar, root-actor.
for host, port in ponged_addrs: for addr in ponged_addrs:
# NOTE: zero triggers dynamic OS port allocation waddr = wrap_address(addr)
trans_bind_addrs.append((host, 0)) print(waddr)
trans_bind_addrs.append(
waddr.get_random(namespace=waddr.namespace)
)
# Start this local actor as the "registrar", aka a regular # Start this local actor as the "registrar", aka a regular
# actor who manages the local registry of "mailboxes" of # actor who manages the local registry of "mailboxes" of
@ -322,7 +322,7 @@ async def open_root_actor(
# NOTE that if the current actor IS THE REGISTAR, the # NOTE that if the current actor IS THE REGISTAR, the
# following init steps are taken: # following init steps are taken:
# - the tranport layer server is bound to each (host, port) # - the tranport layer server is bound to each addr
# pair defined in provided registry_addrs, or the default. # pair defined in provided registry_addrs, or the default.
trans_bind_addrs = registry_addrs trans_bind_addrs = registry_addrs
@ -462,7 +462,7 @@ def run_daemon(
# runtime kwargs # runtime kwargs
name: str | None = 'root', name: str | None = 'root',
registry_addrs: list[tuple[str, int]] = _default_lo_addrs, registry_addrs: list[AddressTypes]|None = None,
start_method: str | None = None, start_method: str | None = None,
debug_mode: bool = False, debug_mode: bool = False,

View File

@ -74,6 +74,13 @@ from tractor.msg import (
types as msgtypes, types as msgtypes,
) )
from .ipc import Channel from .ipc import Channel
from ._addr import (
AddressTypes,
Address,
wrap_address,
preferred_transport,
default_lo_addrs
)
from ._context import ( from ._context import (
mk_context, mk_context,
Context, Context,
@ -179,11 +186,11 @@ class Actor:
enable_modules: list[str] = [], enable_modules: list[str] = [],
uid: str|None = None, uid: str|None = None,
loglevel: str|None = None, loglevel: str|None = None,
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
spawn_method: str|None = None, spawn_method: str|None = None,
# TODO: remove! # TODO: remove!
arbiter_addr: tuple[str, int]|None = None, arbiter_addr: AddressTypes|None = None,
) -> None: ) -> None:
''' '''
@ -223,7 +230,7 @@ class Actor:
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
registry_addrs: list[tuple[str, int]] = [arbiter_addr] registry_addrs: list[AddressTypes] = [arbiter_addr]
# marked by the process spawning backend at startup # marked by the process spawning backend at startup
# will be None for the parent most process started manually # will be None for the parent most process started manually
@ -257,6 +264,7 @@ class Actor:
] = {} ] = {}
self._listeners: list[trio.abc.Listener] = [] self._listeners: list[trio.abc.Listener] = []
self._listen_addrs: list[Address] = []
self._parent_chan: Channel|None = None self._parent_chan: Channel|None = None
self._forkserver_info: tuple|None = None self._forkserver_info: tuple|None = None
@ -269,13 +277,13 @@ class Actor:
# when provided, init the registry addresses property from # when provided, init the registry addresses property from
# input via the validator. # input via the validator.
self._reg_addrs: list[tuple[str, int]] = [] self._reg_addrs: list[AddressTypes] = []
if registry_addrs: if registry_addrs:
self.reg_addrs: list[tuple[str, int]] = registry_addrs self.reg_addrs: list[AddressTypes] = registry_addrs
_state._runtime_vars['_registry_addrs'] = registry_addrs _state._runtime_vars['_registry_addrs'] = registry_addrs
@property @property
def reg_addrs(self) -> list[tuple[str, int]]: def reg_addrs(self) -> list[AddressTypes]:
''' '''
List of (socket) addresses for all known (and contactable) List of (socket) addresses for all known (and contactable)
registry actors. registry actors.
@ -286,7 +294,7 @@ class Actor:
@reg_addrs.setter @reg_addrs.setter
def reg_addrs( def reg_addrs(
self, self,
addrs: list[tuple[str, int]], addrs: list[AddressTypes],
) -> None: ) -> None:
if not addrs: if not addrs:
log.warning( log.warning(
@ -295,15 +303,6 @@ class Actor:
) )
return return
# always sanity check the input list since it's critical
# that addrs are correct for discovery sys operation.
for addr in addrs:
if not isinstance(addr, tuple):
raise ValueError(
'Expected `Actor.reg_addrs: list[tuple[str, int]]`\n'
f'Got {addrs}'
)
self._reg_addrs = addrs self._reg_addrs = addrs
async def wait_for_peer( async def wait_for_peer(
@ -1024,11 +1023,11 @@ class Actor:
async def _from_parent( async def _from_parent(
self, self,
parent_addr: tuple[str, int]|None, parent_addr: AddressTypes|None,
) -> tuple[ ) -> tuple[
Channel, Channel,
list[tuple[str, int]]|None, list[AddressTypes]|None,
]: ]:
''' '''
Bootstrap this local actor's runtime config from its parent by Bootstrap this local actor's runtime config from its parent by
@ -1040,16 +1039,13 @@ class Actor:
# Connect back to the parent actor and conduct initial # Connect back to the parent actor and conduct initial
# handshake. From this point on if we error, we # handshake. From this point on if we error, we
# attempt to ship the exception back to the parent. # attempt to ship the exception back to the parent.
chan = Channel( chan = await Channel.from_addr(wrap_address(parent_addr))
destaddr=parent_addr,
)
await chan.connect()
# TODO: move this into a `Channel.handshake()`? # TODO: move this into a `Channel.handshake()`?
# Initial handshake: swap names. # Initial handshake: swap names.
await self._do_handshake(chan) await self._do_handshake(chan)
accept_addrs: list[tuple[str, int]]|None = None accept_addrs: list[AddressTypes]|None = None
if self._spawn_method == "trio": if self._spawn_method == "trio":
@ -1066,7 +1062,7 @@ class Actor:
# if "trace"/"util" mode is enabled? # if "trace"/"util" mode is enabled?
f'{pretty_struct.pformat(spawnspec)}\n' f'{pretty_struct.pformat(spawnspec)}\n'
) )
accept_addrs: list[tuple[str, int]] = spawnspec.bind_addrs accept_addrs: list[AddressTypes] = spawnspec.bind_addrs
# TODO: another `Struct` for rtvs.. # TODO: another `Struct` for rtvs..
rvs: dict[str, Any] = spawnspec._runtime_vars rvs: dict[str, Any] = spawnspec._runtime_vars
@ -1173,8 +1169,7 @@ class Actor:
self, self,
handler_nursery: Nursery, handler_nursery: Nursery,
*, *,
# (host, port) to bind for channel server listen_addrs: list[AddressTypes]|None = None,
listen_sockaddrs: list[tuple[str, int]]|None = None,
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
@ -1186,41 +1181,45 @@ class Actor:
`.cancel_server()` is called. `.cancel_server()` is called.
''' '''
if listen_sockaddrs is None: if listen_addrs is None:
listen_sockaddrs = [(None, 0)] listen_addrs = default_lo_addrs([preferred_transport])
else:
listen_addrs: list[Address] = [
wrap_address(a) for a in listen_addrs
]
self._server_down = trio.Event() self._server_down = trio.Event()
try: try:
async with trio.open_nursery() as server_n: async with trio.open_nursery() as server_n:
listeners: list[trio.abc.Listener] = [
for host, port in listen_sockaddrs: await addr.open_listener()
listeners: list[trio.abc.Listener] = await server_n.start( for addr in listen_addrs
]
await server_n.start(
partial( partial(
trio.serve_tcp, trio.serve_listeners,
handler=self._stream_handler, handler=self._stream_handler,
port=port, listeners=listeners,
host=host,
# NOTE: configured such that new # NOTE: configured such that new
# connections will stay alive even if # connections will stay alive even if
# this server is cancelled! # this server is cancelled!
handler_nursery=handler_nursery, handler_nursery=handler_nursery
) )
) )
sockets: list[trio.socket] = [
getattr(listener, 'socket', 'unknown socket')
for listener in listeners
]
log.runtime( log.runtime(
'Started TCP server(s)\n' 'Started server(s)\n'
f'|_{sockets}\n' '\n'.join([f'|_{addr}' for addr in listen_addrs])
) )
self._listen_addrs.extend(listen_addrs)
self._listeners.extend(listeners) self._listeners.extend(listeners)
task_status.started(server_n) task_status.started(server_n)
finally: finally:
for addr in listen_addrs:
await addr.close_listener()
# signal the server is down since nursery above terminated # signal the server is down since nursery above terminated
self._server_down.set() self._server_down.set()
@ -1579,26 +1578,21 @@ class Actor:
return False return False
@property @property
def accept_addrs(self) -> list[tuple[str, int]]: def accept_addrs(self) -> list[AddressTypes]:
''' '''
All addresses to which the transport-channel server binds All addresses to which the transport-channel server binds
and listens for new connections. and listens for new connections.
''' '''
# throws OSError on failure return [a.unwrap() for a in self._listen_addrs]
return [
listener.socket.getsockname()
for listener in self._listeners
] # type: ignore
@property @property
def accept_addr(self) -> tuple[str, int]: def accept_addr(self) -> AddressTypes:
''' '''
Primary address to which the IPC transport server is Primary address to which the IPC transport server is
bound and listening for new connections. bound and listening for new connections.
''' '''
# throws OSError on failure
return self.accept_addrs[0] return self.accept_addrs[0]
def get_parent(self) -> Portal: def get_parent(self) -> Portal:
@ -1670,7 +1664,7 @@ class Actor:
async def async_main( async def async_main(
actor: Actor, actor: Actor,
accept_addrs: tuple[str, int]|None = None, accept_addrs: AddressTypes|None = None,
# XXX: currently ``parent_addr`` is only needed for the # XXX: currently ``parent_addr`` is only needed for the
# ``multiprocessing`` backend (which pickles state sent to # ``multiprocessing`` backend (which pickles state sent to
@ -1679,7 +1673,7 @@ async def async_main(
# change this to a simple ``is_subactor: bool`` which will # change this to a simple ``is_subactor: bool`` which will
# be False when running as root actor and True when as # be False when running as root actor and True when as
# a subactor. # a subactor.
parent_addr: tuple[str, int]|None = None, parent_addr: AddressTypes|None = None,
task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
@ -1769,7 +1763,7 @@ async def async_main(
partial( partial(
actor._serve_forever, actor._serve_forever,
service_nursery, service_nursery,
listen_sockaddrs=accept_addrs, listen_addrs=accept_addrs,
) )
) )
except OSError as oserr: except OSError as oserr:
@ -1785,7 +1779,7 @@ async def async_main(
raise raise
accept_addrs: list[tuple[str, int]] = actor.accept_addrs accept_addrs: list[AddressTypes] = actor.accept_addrs
# NOTE: only set the loopback addr for the # NOTE: only set the loopback addr for the
# process-tree-global "root" mailbox since # process-tree-global "root" mailbox since
@ -1793,9 +1787,8 @@ async def async_main(
# their root actor over that channel. # their root actor over that channel.
if _state._runtime_vars['_is_root']: if _state._runtime_vars['_is_root']:
for addr in accept_addrs: for addr in accept_addrs:
host, _ = addr waddr = wrap_address(addr)
# TODO: generic 'lo' detector predicate if waddr == waddr.get_root():
if '127.0.0.1' in host:
_state._runtime_vars['_root_mailbox'] = addr _state._runtime_vars['_root_mailbox'] = addr
# Register with the arbiter if we're told its addr # Register with the arbiter if we're told its addr
@ -1810,24 +1803,21 @@ async def async_main(
# only on unique actor uids? # only on unique actor uids?
for addr in actor.reg_addrs: for addr in actor.reg_addrs:
try: try:
assert isinstance(addr, tuple) waddr = wrap_address(addr)
assert addr[1] # non-zero after bind assert waddr.is_valid
except AssertionError: except AssertionError:
await _debug.pause() await _debug.pause()
async with get_registry(*addr) as reg_portal: async with get_registry(addr) as reg_portal:
for accept_addr in accept_addrs: for accept_addr in accept_addrs:
accept_addr = wrap_address(accept_addr)
if not accept_addr[1]: assert accept_addr.is_valid
await _debug.pause()
assert accept_addr[1]
await reg_portal.run_from_ns( await reg_portal.run_from_ns(
'self', 'self',
'register_actor', 'register_actor',
uid=actor.uid, uid=actor.uid,
sockaddr=accept_addr, addr=accept_addr.unwrap(),
) )
is_registered: bool = True is_registered: bool = True
@ -1954,12 +1944,13 @@ async def async_main(
): ):
failed: bool = False failed: bool = False
for addr in actor.reg_addrs: for addr in actor.reg_addrs:
assert isinstance(addr, tuple) waddr = wrap_address(addr)
assert waddr.is_valid
with trio.move_on_after(0.5) as cs: with trio.move_on_after(0.5) as cs:
cs.shield = True cs.shield = True
try: try:
async with get_registry( async with get_registry(
*addr, addr,
) as reg_portal: ) as reg_portal:
await reg_portal.run_from_ns( await reg_portal.run_from_ns(
'self', 'self',
@ -2037,7 +2028,7 @@ class Arbiter(Actor):
self._registry: dict[ self._registry: dict[
tuple[str, str], tuple[str, str],
tuple[str, int], AddressTypes,
] = {} ] = {}
self._waiters: dict[ self._waiters: dict[
str, str,
@ -2053,18 +2044,18 @@ class Arbiter(Actor):
self, self,
name: str, name: str,
) -> tuple[str, int]|None: ) -> AddressTypes|None:
for uid, sockaddr in self._registry.items(): for uid, addr in self._registry.items():
if name in uid: if name in uid:
return sockaddr return addr
return None return None
async def get_registry( async def get_registry(
self self
) -> dict[str, tuple[str, int]]: ) -> dict[str, AddressTypes]:
''' '''
Return current name registry. Return current name registry.
@ -2084,7 +2075,7 @@ class Arbiter(Actor):
self, self,
name: str, name: str,
) -> list[tuple[str, int]]: ) -> list[AddressTypes]:
''' '''
Wait for a particular actor to register. Wait for a particular actor to register.
@ -2092,44 +2083,41 @@ class Arbiter(Actor):
registered. registered.
''' '''
sockaddrs: list[tuple[str, int]] = [] addrs: list[AddressTypes] = []
sockaddr: tuple[str, int] addr: AddressTypes
mailbox_info: str = 'Actor registry contact infos:\n' mailbox_info: str = 'Actor registry contact infos:\n'
for uid, sockaddr in self._registry.items(): for uid, addr in self._registry.items():
mailbox_info += ( mailbox_info += (
f'|_uid: {uid}\n' f'|_uid: {uid}\n'
f'|_sockaddr: {sockaddr}\n\n' f'|_addr: {addr}\n\n'
) )
if name == uid[0]: if name == uid[0]:
sockaddrs.append(sockaddr) addrs.append(addr)
if not sockaddrs: if not addrs:
waiter = trio.Event() waiter = trio.Event()
self._waiters.setdefault(name, []).append(waiter) self._waiters.setdefault(name, []).append(waiter)
await waiter.wait() await waiter.wait()
for uid in self._waiters[name]: for uid in self._waiters[name]:
if not isinstance(uid, trio.Event): if not isinstance(uid, trio.Event):
sockaddrs.append(self._registry[uid]) addrs.append(self._registry[uid])
log.runtime(mailbox_info) log.runtime(mailbox_info)
return sockaddrs return addrs
async def register_actor( async def register_actor(
self, self,
uid: tuple[str, str], uid: tuple[str, str],
sockaddr: tuple[str, int] addr: AddressTypes
) -> None: ) -> None:
uid = name, hash = (str(uid[0]), str(uid[1])) uid = name, hash = (str(uid[0]), str(uid[1]))
addr = (host, port) = ( waddr: Address = wrap_address(addr)
str(sockaddr[0]), if not waddr.is_valid:
int(sockaddr[1]), # should never be 0-dynamic-os-alloc
)
if port == 0:
await _debug.pause() await _debug.pause()
assert port # should never be 0-dynamic-os-alloc
self._registry[uid] = addr self._registry[uid] = addr
# pop and signal all waiter events # pop and signal all waiter events

View File

@ -46,6 +46,7 @@ from tractor._state import (
_runtime_vars, _runtime_vars,
) )
from tractor.log import get_logger from tractor.log import get_logger
from tractor._addr import AddressTypes
from tractor._portal import Portal from tractor._portal import Portal
from tractor._runtime import Actor from tractor._runtime import Actor
from tractor._entry import _mp_main from tractor._entry import _mp_main
@ -392,8 +393,8 @@ async def new_proc(
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
@ -431,8 +432,8 @@ async def trio_proc(
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
infect_asyncio: bool = False, infect_asyncio: bool = False,
@ -520,15 +521,15 @@ async def trio_proc(
# send a "spawning specification" which configures the # send a "spawning specification" which configures the
# initial runtime state of the child. # initial runtime state of the child.
await chan.send( sspec = SpawnSpec(
SpawnSpec(
_parent_main_data=subactor._parent_main_data, _parent_main_data=subactor._parent_main_data,
enable_modules=subactor.enable_modules, enable_modules=subactor.enable_modules,
reg_addrs=subactor.reg_addrs, reg_addrs=subactor.reg_addrs,
bind_addrs=bind_addrs, bind_addrs=bind_addrs,
_runtime_vars=_runtime_vars, _runtime_vars=_runtime_vars,
) )
) log.runtime(f'Sending spawn spec: {str(sspec)}')
await chan.send(sspec)
# track subactor in current nursery # track subactor in current nursery
curr_actor: Actor = current_actor() curr_actor: Actor = current_actor()
@ -638,8 +639,8 @@ async def mp_proc(
subactor: Actor, subactor: Actor,
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
infect_asyncio: bool = False, infect_asyncio: bool = False,

View File

@ -28,7 +28,13 @@ import warnings
import trio import trio
from .devx._debug import maybe_wait_for_debugger from .devx._debug import maybe_wait_for_debugger
from ._addr import (
AddressTypes,
preferred_transport,
get_address_cls
)
from ._state import current_actor, is_main_process from ._state import current_actor, is_main_process
from .log import get_logger, get_loglevel from .log import get_logger, get_loglevel
from ._runtime import Actor from ._runtime import Actor
@ -47,8 +53,6 @@ if TYPE_CHECKING:
log = get_logger(__name__) log = get_logger(__name__)
_default_bind_addr: tuple[str, int] = ('127.0.0.1', 0)
class ActorNursery: class ActorNursery:
''' '''
@ -130,8 +134,9 @@ class ActorNursery:
*, *,
bind_addrs: list[tuple[str, int]] = [_default_bind_addr], bind_addrs: list[AddressTypes]|None = None,
rpc_module_paths: list[str]|None = None, rpc_module_paths: list[str]|None = None,
enable_transports: list[str] = [preferred_transport],
enable_modules: list[str]|None = None, enable_modules: list[str]|None = None,
loglevel: str|None = None, # set log level per subactor loglevel: str|None = None, # set log level per subactor
debug_mode: bool|None = None, debug_mode: bool|None = None,
@ -156,6 +161,12 @@ class ActorNursery:
or get_loglevel() or get_loglevel()
) )
if not bind_addrs:
bind_addrs: list[AddressTypes] = [
get_address_cls(transport).get_random().unwrap()
for transport in enable_transports
]
# configure and pass runtime state # configure and pass runtime state
_rtv = _state._runtime_vars.copy() _rtv = _state._runtime_vars.copy()
_rtv['_is_root'] = False _rtv['_is_root'] = False
@ -224,7 +235,7 @@ class ActorNursery:
*, *,
name: str | None = None, name: str | None = None,
bind_addrs: tuple[str, int] = [_default_bind_addr], bind_addrs: AddressTypes|None = None,
rpc_module_paths: list[str] | None = None, rpc_module_paths: list[str] | None = None,
enable_modules: list[str] | None = None, enable_modules: list[str] | None = None,
loglevel: str | None = None, # set log level per subactor loglevel: str | None = None, # set log level per subactor

View File

@ -2,19 +2,59 @@ import os
import random import random
def generate_single_byte_msgs(amount: int) -> bytes:
'''
Generate a byte instance of len `amount` with:
```
byte_at_index(i) = (i % 10).encode()
```
this results in constantly repeating sequences of:
b'0123456789'
'''
return b''.join(str(i % 10).encode() for i in range(amount))
def generate_sample_messages( def generate_sample_messages(
amount: int, amount: int,
rand_min: int = 0, rand_min: int = 0,
rand_max: int = 0, rand_max: int = 0,
silent: bool = False silent: bool = False,
) -> tuple[list[bytes], int]: ) -> tuple[list[bytes], int]:
'''
Generate bytes msgs for tests.
Messages will have the following format:
```
b'[{i:08}]' + os.urandom(random.randint(rand_min, rand_max))
```
so for message index 25:
b'[00000025]' + random_bytes
'''
msgs = [] msgs = []
size = 0 size = 0
log_interval = None
if not silent: if not silent:
print(f'\ngenerating {amount} messages...') print(f'\ngenerating {amount} messages...')
# calculate an apropiate log interval based on
# max message size
max_msg_size = 10 + rand_max
if max_msg_size <= 32 * 1024:
log_interval = 10_000
else:
log_interval = 1000
for i in range(amount): for i in range(amount):
msg = f'[{i:08}]'.encode('utf-8') msg = f'[{i:08}]'.encode('utf-8')
@ -26,7 +66,13 @@ def generate_sample_messages(
msgs.append(msg) msgs.append(msg)
if not silent and i and i % 10_000 == 0: if (
not silent
and
i > 0
and
i % log_interval == 0
):
print(f'{i} generated') print(f'{i} generated')
if not silent: if not silent:

View File

@ -13,20 +13,25 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import platform import platform
from ._transport import MsgTransport as MsgTransport from ._transport import (
MsgTransportKey as MsgTransportKey,
MsgType as MsgType,
MsgTransport as MsgTransport,
MsgpackTransport as MsgpackTransport
)
from ._tcp import ( from ._tcp import MsgpackTCPStream as MsgpackTCPStream
get_stream_addrs as get_stream_addrs, from ._uds import MsgpackUDSStream as MsgpackUDSStream
MsgpackTCPStream as MsgpackTCPStream
from ._types import (
transport_from_addr as transport_from_addr,
transport_from_stream as transport_from_stream,
) )
from ._chan import ( from ._chan import (
_connect_chan as _connect_chan, _connect_chan as _connect_chan,
get_msg_transport as get_msg_transport,
Channel as Channel Channel as Channel
) )
@ -39,12 +44,23 @@ if platform.system() == 'Linux':
write_eventfd as write_eventfd, write_eventfd as write_eventfd,
read_eventfd as read_eventfd, read_eventfd as read_eventfd,
close_eventfd as close_eventfd, close_eventfd as close_eventfd,
EFDReadCancelled as EFDReadCancelled,
EventFD as EventFD, EventFD as EventFD,
) )
from ._ringbuf import ( from ._ringbuf import (
RBToken as RBToken, RBToken as RBToken,
open_ringbuf as open_ringbuf,
RingBuffSender as RingBuffSender, RingBuffSender as RingBuffSender,
RingBuffReceiver as RingBuffReceiver, RingBuffReceiver as RingBuffReceiver,
open_ringbuf as open_ringbuf open_ringbuf_pair as open_ringbuf_pair,
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
attach_to_ringbuf_sender as attach_to_ringbuf_sender,
attach_to_ringbuf_stream as attach_to_ringbuf_stream,
RingBuffBytesSender as RingBuffBytesSender,
RingBuffBytesReceiver as RingBuffBytesReceiver,
RingBuffChannel as RingBuffChannel,
attach_to_ringbuf_schannel as attach_to_ringbuf_schannel,
attach_to_ringbuf_rchannel as attach_to_ringbuf_rchannel,
attach_to_ringbuf_channel as attach_to_ringbuf_channel,
) )

View File

@ -29,15 +29,19 @@ from pprint import pformat
import typing import typing
from typing import ( from typing import (
Any, Any,
Type
) )
import trio import trio
from tractor.ipc._transport import MsgTransport from tractor.ipc._transport import MsgTransport
from tractor.ipc._tcp import ( from tractor.ipc._types import (
MsgpackTCPStream, transport_from_addr,
get_stream_addrs transport_from_stream,
)
from tractor._addr import (
wrap_address,
Address,
AddressTypes
) )
from tractor.log import get_logger from tractor.log import get_logger
from tractor._exceptions import ( from tractor._exceptions import (
@ -52,17 +56,6 @@ log = get_logger(__name__)
_is_windows = platform.system() == 'Windows' _is_windows = platform.system() == 'Windows'
def get_msg_transport(
key: tuple[str, str],
) -> Type[MsgTransport]:
return {
('msgpack', 'tcp'): MsgpackTCPStream,
}[key]
class Channel: class Channel:
''' '''
An inter-process channel for communication between (remote) actors. An inter-process channel for communication between (remote) actors.
@ -77,10 +70,7 @@ class Channel:
def __init__( def __init__(
self, self,
destaddr: tuple[str, int]|None, transport: MsgTransport|None = None,
msg_transport_type_key: tuple[str, str] = ('msgpack', 'tcp'),
# TODO: optional reconnection support? # TODO: optional reconnection support?
# auto_reconnect: bool = False, # auto_reconnect: bool = False,
# on_reconnect: typing.Callable[..., typing.Awaitable] = None, # on_reconnect: typing.Callable[..., typing.Awaitable] = None,
@ -90,13 +80,9 @@ class Channel:
# self._recon_seq = on_reconnect # self._recon_seq = on_reconnect
# self._autorecon = auto_reconnect # self._autorecon = auto_reconnect
self._destaddr = destaddr
self._transport_key = msg_transport_type_key
# Either created in ``.connect()`` or passed in by # Either created in ``.connect()`` or passed in by
# user in ``.from_stream()``. # user in ``.from_stream()``.
self._stream: trio.SocketStream|None = None self._transport: MsgTransport|None = transport
self._transport: MsgTransport|None = None
# set after handshake - always uid of far end # set after handshake - always uid of far end
self.uid: tuple[str, str]|None = None self.uid: tuple[str, str]|None = None
@ -110,6 +96,10 @@ class Channel:
# runtime. # runtime.
self._cancel_called: bool = False self._cancel_called: bool = False
@property
def stream(self) -> trio.abc.Stream | None:
return self._transport.stream if self._transport else None
@property @property
def msgstream(self) -> MsgTransport: def msgstream(self) -> MsgTransport:
log.info( log.info(
@ -124,52 +114,32 @@ class Channel:
@classmethod @classmethod
def from_stream( def from_stream(
cls, cls,
stream: trio.SocketStream, stream: trio.abc.Stream,
**kwargs,
) -> Channel: ) -> Channel:
transport_cls = transport_from_stream(stream)
src, dst = get_stream_addrs(stream) return Channel(
chan = Channel( transport=transport_cls(stream)
destaddr=dst,
**kwargs,
) )
# set immediately here from provided instance @classmethod
chan._stream: trio.SocketStream = stream async def from_addr(
chan.set_msg_transport(stream) cls,
return chan addr: AddressTypes,
**kwargs
) -> Channel:
addr: Address = wrap_address(addr)
transport_cls = transport_from_addr(addr)
transport = await transport_cls.connect_to(addr, **kwargs)
def set_msg_transport( log.transport(
self, f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}'
stream: trio.SocketStream,
type_key: tuple[str, str]|None = None,
# XXX optionally provided codec pair for `msgspec`:
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
codec: MsgCodec|None = None,
) -> MsgTransport:
type_key = (
type_key
or
self._transport_key
) )
# get transport type, then return Channel(transport=transport)
self._transport = get_msg_transport(
type_key
# instantiate an instance of the msg-transport
)(
stream,
codec=codec,
)
return self._transport
@cm @cm
def apply_codec( def apply_codec(
self, self,
codec: MsgCodec, codec: MsgCodec,
) -> None: ) -> None:
''' '''
Temporarily override the underlying IPC msg codec for Temporarily override the underlying IPC msg codec for
@ -189,44 +159,20 @@ class Channel:
return '<Channel with inactive transport?>' return '<Channel with inactive transport?>'
return repr( return repr(
self._transport.stream.socket._sock self._transport
).replace( # type: ignore ).replace( # type: ignore
"socket.socket", "socket.socket",
"Channel", "Channel",
) )
@property @property
def laddr(self) -> tuple[str, int]|None: def laddr(self) -> Address|None:
return self._transport.laddr if self._transport else None return self._transport.laddr if self._transport else None
@property @property
def raddr(self) -> tuple[str, int]|None: def raddr(self) -> Address|None:
return self._transport.raddr if self._transport else None return self._transport.raddr if self._transport else None
async def connect(
self,
destaddr: tuple[Any, ...] | None = None,
**kwargs
) -> MsgTransport:
if self.connected():
raise RuntimeError("channel is already connected?")
destaddr = destaddr or self._destaddr
assert isinstance(destaddr, tuple)
stream = await trio.open_tcp_stream(
*destaddr,
**kwargs
)
transport = self.set_msg_transport(stream)
log.transport(
f'Opened channel[{type(transport)}]: {self.laddr} -> {self.raddr}'
)
return transport
# TODO: something like, # TODO: something like,
# `pdbp.hideframe_on(errors=[MsgTypeError])` # `pdbp.hideframe_on(errors=[MsgTypeError])`
# instead of the `try/except` hack we have rn.. # instead of the `try/except` hack we have rn..
@ -261,8 +207,12 @@ class Channel:
# assert err # assert err
__tracebackhide__: bool = False __tracebackhide__: bool = False
else: else:
try:
assert err.cid assert err.cid
except KeyError:
raise err
raise raise
async def recv(self) -> Any: async def recv(self) -> Any:
@ -388,17 +338,14 @@ class Channel:
@acm @acm
async def _connect_chan( async def _connect_chan(
host: str, addr: AddressTypes
port: int
) -> typing.AsyncGenerator[Channel, None]: ) -> typing.AsyncGenerator[Channel, None]:
''' '''
Create and connect a channel with disconnect on context manager Create and connect a channel with disconnect on context manager
teardown. teardown.
''' '''
chan = Channel((host, port)) chan = await Channel.from_addr(addr)
await chan.connect()
yield chan yield chan
with trio.CancelScope(shield=True): with trio.CancelScope(shield=True):
await chan.aclose() await chan.aclose()

View File

@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
raise OSError(errno.errorcode[ffi.errno], 'close failed') raise OSError(errno.errorcode[ffi.errno], 'close failed')
class EFDReadCancelled(Exception):
    '''
    Raised by `EventFD.read()` when the cancel scope wrapping the
    threaded read was cancelled, which is what `EventFD.close()`
    does to any read still in flight.

    '''
class EventFD: class EventFD:
''' '''
Use a previously opened eventfd(2), meant to be used in Use a previously opened eventfd(2), meant to be used in
@ -124,6 +128,7 @@ class EventFD:
self._fd: int = fd self._fd: int = fd
self._omode: str = omode self._omode: str = omode
self._fobj = None self._fobj = None
self._cscope: trio.CancelScope | None = None
@property @property
def fd(self) -> int | None: def fd(self) -> int | None:
@ -133,18 +138,47 @@ class EventFD:
return write_eventfd(self._fd, value) return write_eventfd(self._fd, value)
async def read(self) -> int: async def read(self) -> int:
'''
Async wrapper for `read_eventfd(self.fd)`
`trio.to_thread.run_sync` is used, need to use a `trio.CancelScope`
in order to make it cancellable when `self.close()` is called.
'''
self._cscope = trio.CancelScope()
with self._cscope:
return await trio.to_thread.run_sync( return await trio.to_thread.run_sync(
read_eventfd, self._fd, read_eventfd, self._fd,
abandon_on_cancel=True abandon_on_cancel=True
) )
if self._cscope.cancelled_caught:
raise EFDReadCancelled
self._cscope = None
def read_direct(self) -> int:
    '''
    Synchronous call to `read_eventfd()` on our fd, with no
    thread offload.

    Blocks the calling thread unless the `eventfd` was opened
    with `EFD_NONBLOCK`.

    '''
    fd: int = self._fd
    return read_eventfd(fd)
def open(self): def open(self):
self._fobj = os.fdopen(self._fd, self._omode) self._fobj = os.fdopen(self._fd, self._omode)
def close(self): def close(self):
if self._fobj: if self._fobj:
try:
self._fobj.close() self._fobj.close()
except OSError:
...
if self._cscope:
self._cscope.cancel()
def __enter__(self): def __enter__(self):
self.open() self.open()
return self return self

View File

@ -18,7 +18,15 @@ IPC Reliable RingBuffer implementation
''' '''
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager as cm import struct
from typing import (
ContextManager,
AsyncContextManager
)
from contextlib import (
contextmanager as cm,
asynccontextmanager as acm
)
from multiprocessing.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
import trio import trio
@ -28,25 +36,37 @@ from msgspec import (
) )
from ._linux import ( from ._linux import (
EFD_NONBLOCK,
open_eventfd, open_eventfd,
EFDReadCancelled,
EventFD EventFD
) )
from ._mp_bs import disable_mantracker from ._mp_bs import disable_mantracker
from tractor.log import get_logger
from tractor._exceptions import (
InternalError
)
log = get_logger(__name__)
disable_mantracker() disable_mantracker()
_DEFAULT_RB_SIZE = 10 * 1024
class RBToken(Struct, frozen=True): class RBToken(Struct, frozen=True):
''' '''
RingBuffer token contains necesary info to open the two RingBuffer token contains necesary info to open the three
eventfds and the shared memory eventfds and the shared memory
''' '''
shm_name: str shm_name: str
write_eventfd: int
wrap_eventfd: int write_eventfd: int # used to signal writer ptr advance
wrap_eventfd: int # used to signal reader ready after wrap around
eof_eventfd: int # used to signal writer closed
buf_size: int buf_size: int
def as_msg(self): def as_msg(self):
@ -59,24 +79,45 @@ class RBToken(Struct, frozen=True):
return RBToken(**msg) return RBToken(**msg)
@property
def fds(self) -> tuple[int, int, int]:
    '''
    The three eventfd numbers as a `(write, wrap, eof)` tuple,
    handy for `pass_fds` params.

    '''
    write_fd: int = self.write_eventfd
    wrap_fd: int = self.wrap_eventfd
    eof_fd: int = self.eof_eventfd
    return write_fd, wrap_fd, eof_fd
@cm @cm
def open_ringbuf( def open_ringbuf(
shm_name: str, shm_name: str,
buf_size: int = 10 * 1024, buf_size: int = _DEFAULT_RB_SIZE,
write_efd_flags: int = 0, ) -> ContextManager[RBToken]:
wrap_efd_flags: int = 0 '''
) -> RBToken: Handle resources for a ringbuf (shm, eventfd), yield `RBToken` to
be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver`
'''
shm = SharedMemory( shm = SharedMemory(
name=shm_name, name=shm_name,
size=buf_size, size=buf_size,
create=True create=True
) )
try: try:
with (
EventFD(open_eventfd(), 'r') as write_event,
EventFD(open_eventfd(), 'r') as wrap_event,
EventFD(open_eventfd(), 'r') as eof_event,
):
token = RBToken( token = RBToken(
shm_name=shm_name, shm_name=shm_name,
write_eventfd=open_eventfd(flags=write_efd_flags), write_eventfd=write_event.fd,
wrap_eventfd=open_eventfd(flags=wrap_efd_flags), wrap_eventfd=wrap_event.fd,
eof_eventfd=eof_event.fd,
buf_size=buf_size buf_size=buf_size
) )
yield token yield token
@ -85,36 +126,50 @@ def open_ringbuf(
shm.unlink() shm.unlink()
Buffer = bytes | bytearray | memoryview
'''
IPC Reliable Ring Buffer
`eventfd(2)` is used for wrap around sync, to signal writes to
the reader and end of stream.
'''
class RingBuffSender(trio.abc.SendStream): class RingBuffSender(trio.abc.SendStream):
''' '''
IPC Reliable Ring Buffer sender side implementation Ring Buffer sender side implementation
`eventfd(2)` is used for wrap around sync, and also to signal Do not use directly! manage with `attach_to_ringbuf_sender`
writes to the reader. after having opened a ringbuf context with `open_ringbuf`.
''' '''
def __init__( def __init__(
self, self,
token: RBToken, token: RBToken,
start_ptr: int = 0, cleanup: bool = False
): ):
token = RBToken.from_msg(token) self._token = RBToken.from_msg(token)
self._shm = SharedMemory( self._shm: SharedMemory | None = None
name=token.shm_name, self._write_event = EventFD(self._token.write_eventfd, 'w')
size=token.buf_size, self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
create=False self._eof_event = EventFD(self._token.eof_eventfd, 'w')
) self._ptr = 0
self._write_event = EventFD(token.write_eventfd, 'w')
self._wrap_event = EventFD(token.wrap_eventfd, 'r') self._cleanup = cleanup
self._ptr = start_ptr self._send_lock = trio.StrictFIFOLock()
@property @property
def key(self) -> str: def name(self) -> str:
if not self._shm:
raise ValueError('shared memory not initialized yet!')
return self._shm.name return self._shm.name
@property @property
def size(self) -> int: def size(self) -> int:
return self._shm.size return self._token.buf_size
@property @property
def ptr(self) -> int: def ptr(self) -> int:
@ -128,7 +183,11 @@ class RingBuffSender(trio.abc.SendStream):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
async def send_all(self, data: bytes | bytearray | memoryview): async def _wait_wrap(self):
await self._wrap_event.read()
async def send_all(self, data: Buffer):
async with self._send_lock:
# while data is larger than the remaining buf # while data is larger than the remaining buf
target_ptr = self.ptr + len(data) target_ptr = self.ptr + len(data)
while target_ptr > self.size: while target_ptr > self.size:
@ -137,7 +196,7 @@ class RingBuffSender(trio.abc.SendStream):
self._shm.buf[self.ptr:] = data[:remaining] self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around # signal write and wait for reader wrap around
self._write_event.write(remaining) self._write_event.write(remaining)
await self._wrap_event.read() await self._wait_wrap()
# wrap around and trim already written bytes # wrap around and trim already written bytes
self._ptr = 0 self._ptr = 0
@ -152,49 +211,69 @@ class RingBuffSender(trio.abc.SendStream):
async def wait_send_all_might_not_block(self): async def wait_send_all_might_not_block(self):
raise NotImplementedError raise NotImplementedError
async def aclose(self): def open(self):
self._write_event.close() self._shm = SharedMemory(
self._wrap_event.close() name=self._token.shm_name,
self._shm.close() size=self._token.buf_size,
create=False
async def __aenter__(self): )
self._write_event.open() self._write_event.open()
self._wrap_event.open() self._wrap_event.open()
self._eof_event.open()
# Signal end-of-stream to the reader, then optionally release resources.
#
# The EOF eventfd is written with the current write pointer, or with
# the full buffer size when the pointer is 0 (presumably because an
# eventfd write of 0 would not wake the reader - TODO confirm).
#
# NOTE(review): indentation is lost in this view, so which of the
# `.close()` calls below are guarded by `if self._cleanup:` has to be
# confirmed against the real file before relying on this comment.
def close(self):
self._eof_event.write(
self._ptr if self._ptr > 0 else self.size
)
if self._cleanup:
self._write_event.close()
self._wrap_event.close()
self._eof_event.close()
self._shm.close()
async def aclose(self):
    '''
    Async close: take the send lock first so `close()` is only
    run while holding it, then release it again.

    '''
    await self._send_lock.acquire()
    try:
        self.close()
    finally:
        self._send_lock.release()
async def __aenter__(self):
self.open()
return self return self
class RingBuffReceiver(trio.abc.ReceiveStream): class RingBuffReceiver(trio.abc.ReceiveStream):
''' '''
IPC Reliable Ring Buffer receiver side implementation Ring Buffer receiver side implementation
`eventfd(2)` is used for wrap around sync, and also to signal Do not use directly! manage with `attach_to_ringbuf_receiver`
writes to the reader. after having opened a ringbuf context with `open_ringbuf`.
''' '''
def __init__( def __init__(
self, self,
token: RBToken, token: RBToken,
start_ptr: int = 0, cleanup: bool = True,
flags: int = 0
): ):
token = RBToken.from_msg(token) self._token = RBToken.from_msg(token)
self._shm = SharedMemory( self._shm: SharedMemory | None = None
name=token.shm_name, self._write_event = EventFD(self._token.write_eventfd, 'w')
size=token.buf_size, self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
create=False self._eof_event = EventFD(self._token.eof_eventfd, 'r')
) self._ptr: int = 0
self._write_event = EventFD(token.write_eventfd, 'w') self._write_ptr: int = 0
self._wrap_event = EventFD(token.wrap_eventfd, 'r') self._end_ptr: int = -1
self._ptr = start_ptr
self._flags = flags self._cleanup: bool = cleanup
@property @property
def key(self) -> str: def name(self) -> str:
if not self._shm:
raise ValueError('shared memory not initialized yet!')
return self._shm.name return self._shm.name
@property @property
def size(self) -> int: def size(self) -> int:
return self._shm.size return self._token.buf_size
@property @property
def ptr(self) -> int: def ptr(self) -> int:
@ -208,46 +287,368 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
async def receive_some( async def _eof_monitor_task(self):
self, '''
max_bytes: int | None = None, Long running EOF event monitor, automatically run in bg by
nb_timeout: float = 0.1 `attach_to_ringbuf_receiver` context manager, if EOF event
) -> memoryview: is set its value will be the end pointer (highest valid
# if non blocking eventfd enabled, do polling index to be read from buf, after setting the `self._end_ptr`
# until next write, this allows signal handling we close the write event which should cancel any blocked
if self._flags | EFD_NONBLOCK: `self._write_event.read()`s on it.
delta = None
while delta is None: '''
try:
self._end_ptr = await self._eof_event.read()
self._write_event.close()
except EFDReadCancelled:
...
except trio.Cancelled:
...
async def receive_some(self, max_bytes: int | None = None) -> bytes:
'''
Receive up to `max_bytes`, if no `max_bytes` is provided
a reasonable default is used.
'''
if max_bytes is None:
max_bytes: int = _DEFAULT_RB_SIZE
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
# delta is remaining bytes we havent read
delta = self._write_ptr - self._ptr
if delta == 0:
# we have read all we can, see if new data is available
if self._end_ptr < 0:
# if we havent been signaled about EOF yet
try: try:
delta = await self._write_event.read() delta = await self._write_event.read()
self._write_ptr += delta
except OSError as e: except EFDReadCancelled:
if e.errno == 'EAGAIN': # while waiting for new data `self._write_event` was closed
continue # this means writer signaled EOF
if self._end_ptr > 0:
raise e # final self._write_ptr modification and recalculate delta
self._write_ptr = self._end_ptr
delta = self._end_ptr - self._ptr
else: else:
delta = await self._write_event.read() # shouldnt happen cause self._eof_monitor_task always sets
# self._end_ptr before closing self._write_event
raise InternalError(
'self._write_event.read cancelled but self._end_ptr is not set'
)
else:
# no more bytes to read and self._end_ptr set, EOF reached
return b''
# dont overflow caller
delta = min(delta, max_bytes)
target_ptr = self._ptr + delta
# fetch next segment and advance ptr # fetch next segment and advance ptr
next_ptr = self._ptr + delta segment = bytes(self._shm.buf[self._ptr:target_ptr])
segment = self._shm.buf[self._ptr:next_ptr] self._ptr = target_ptr
self._ptr = next_ptr
if self.ptr == self.size: if self._ptr == self.size:
# reached the end, signal wrap around # reached the end, signal wrap around
self._ptr = 0 self._ptr = 0
self._write_ptr = 0
self._wrap_event.write(1) self._wrap_event.write(1)
return segment return segment
async def aclose(self): def open(self):
self._write_event.close() self._shm = SharedMemory(
self._wrap_event.close() name=self._token.shm_name,
self._shm.close() size=self._token.buf_size,
create=False
async def __aenter__(self): )
self._write_event.open() self._write_event.open()
self._wrap_event.open() self._wrap_event.open()
self._eof_event.open()
# Release the receiver's eventfds and shared memory handle.
#
# NOTE(review): indentation is lost in this view, so which of the
# `.close()` calls below are guarded by `if self._cleanup:` has to be
# confirmed against the real file before relying on this comment.
def close(self):
if self._cleanup:
self._write_event.close()
self._wrap_event.close()
self._eof_event.close()
self._shm.close()
async def aclose(self):
    '''
    Async flavour of `close()`; simply delegates to the sync
    method.

    '''
    self.close()
async def __aenter__(self):
self.open()
return self return self
@acm
async def attach_to_ringbuf_receiver(
    token: RBToken,
    cleanup: bool = True
) -> AsyncContextManager[RingBuffReceiver]:
    '''
    Attach a RingBuffReceiver from a previously opened
    RBToken.

    Launches `receiver._eof_monitor_task` in a `trio.Nursery` for as
    long as the context is open.

    On exit the nursery is cancelled once the receiver has been
    closed: with `cleanup=False` the EOF eventfd is left open, so the
    monitor task would otherwise keep the nursery (and this context
    manager's exit) blocked until the writer signals EOF.

    '''
    async with trio.open_nursery() as n:
        try:
            async with RingBuffReceiver(
                token,
                cleanup=cleanup
            ) as receiver:
                n.start_soon(receiver._eof_monitor_task)
                yield receiver
        finally:
            # the caller is done with the receiver, nothing left to
            # monitor; never wait on a writer that may not close.
            n.cancel_scope.cancel()
@acm
async def attach_to_ringbuf_sender(
    token: RBToken,
    cleanup: bool = True
) -> AsyncContextManager[RingBuffSender]:
    '''
    Open a `RingBuffSender` on the ringbuf described by an already
    opened `RBToken` and hand it to the caller for the duration of
    the context.

    '''
    sender = RingBuffSender(token, cleanup=cleanup)
    async with sender:
        yield sender
@cm
def open_ringbuf_pair(
    name: str,
    buf_size: int = _DEFAULT_RB_SIZE
) -> ContextManager[tuple[RBToken, RBToken]]:
    '''
    Handle resources for a ringbuf pair to be used for
    bidirectional messaging.

    Two ringbufs named `<name>.pair0` and `<name>.pair1` are opened,
    both with `buf_size` bytes, and their tokens are yielded as a
    2-tuple in that order.

    '''
    with (
        open_ringbuf(
            name + '.pair0',
            buf_size=buf_size
        ) as token_0,

        open_ringbuf(
            name + '.pair1',
            buf_size=buf_size
        ) as token_1
    ):
        yield token_0, token_1
@acm
async def attach_to_ringbuf_stream(
    token_in: RBToken,
    token_out: RBToken,
    cleanup_in: bool = True,
    cleanup_out: bool = True
) -> AsyncContextManager[trio.StapledStream]:
    '''
    Build a bidirectional `trio.StapledStream` on top of an already
    opened ringbuf pair: `token_in` is attached as the receive side,
    `token_out` as the send side.

    '''
    async with attach_to_ringbuf_receiver(
        token_in,
        cleanup=cleanup_in
    ) as receiver:
        async with attach_to_ringbuf_sender(
            token_out,
            cleanup=cleanup_out
        ) as sender:
            stapled = trio.StapledStream(sender, receiver)
            yield stapled
class RingBuffBytesSender(trio.abc.SendChannel[bytes]):
    '''
    In order to guarantee full messages are received, all bytes
    sent by `RingBuffBytesSender` are preceded with a 4 byte header
    which decodes into a (little-endian) uint32 indicating the actual
    size of the next payload.

    Optional batch mode:

    If `batch_size` > 1 messages wont get sent immediately but will be
    stored until `batch_size` messages are pending, then it will send
    them all at once.

    `batch_size` can be changed dynamically; calling `flush()` right
    before is still recommended, but messages already pending are
    never overtaken nor left stranded by a smaller `batch_size`.

    Anything still pending is flushed by `aclose()`.

    '''
    def __init__(
        self,
        sender: RingBuffSender,
        batch_size: int = 1
    ):
        self._sender = sender
        self.batch_size = batch_size
        self._batch_msg_len = 0
        self._batch: bytes = b''

    async def flush(self) -> None:
        '''
        Send every pending batched msg with a single `send_all()`;
        does nothing when no msg is pending.

        '''
        if not self._batch:
            return

        await self._sender.send_all(self._batch)
        self._batch = b''
        self._batch_msg_len = 0

    async def send(self, value: bytes) -> None:
        '''
        Frame `value` with its 4 byte size header and either send it
        right away (non batched mode) or queue it in the current
        batch.

        '''
        msg: bytes = struct.pack("<I", len(value)) + value

        # only bypass the batch when nothing is queued, otherwise this
        # msg would be delivered ahead of older pending ones.
        if self.batch_size == 1 and not self._batch:
            await self._sender.send_all(msg)
            return

        self._batch += msg
        self._batch_msg_len += 1

        # `>=` so that lowering `batch_size` below the pending count
        # still triggers a flush.
        if self._batch_msg_len >= self.batch_size:
            await self.flush()

    async def aclose(self) -> None:
        '''
        Flush whatever is still batched, then close the underlying
        sender (even when the flush fails).

        '''
        try:
            await self.flush()
        finally:
            await self._sender.aclose()
class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]):
    '''
    Receiving end of the size-prefixed framing described in the
    `RingBuffBytesSender` docstring.

    Each `receive()` reads a 4 byte (little-endian) uint32 header and
    then exactly that many payload bytes, looping over the wrapped
    receiver's `receive_some()` until enough bytes have arrived.

    '''
    def __init__(
        self,
        receiver: RingBuffReceiver
    ):
        self._receiver = receiver

    async def _receive_exactly(self, num_bytes: int) -> bytes:
        '''
        Fetch bytes from receiver until we read exactly `num_bytes`;
        raises `trio.EndOfChannel` if end of stream is signaled
        before that.

        '''
        chunks: list[bytes] = []
        remaining: int = num_bytes
        while remaining > 0:
            new_bytes = await self._receiver.receive_some(
                max_bytes=remaining
            )

            # empty read is the receiver's end-of-stream marker
            if new_bytes == b'':
                raise trio.EndOfChannel

            chunks.append(new_bytes)
            remaining -= len(new_bytes)

        return b''.join(chunks)

    async def receive(self) -> bytes:
        '''
        Return the next full msg payload.

        A header announcing a size of 0 is treated as end of channel.

        '''
        header: bytes = await self._receive_exactly(4)
        size: int
        size, = struct.unpack("<I", header)
        if size == 0:
            raise trio.EndOfChannel

        return await self._receive_exactly(size)

    async def aclose(self) -> None:
        '''
        Close the wrapped receiver.

        '''
        await self._receiver.aclose()
@acm
async def attach_to_ringbuf_rchannel(
    token: RBToken,
    cleanup: bool = True
) -> AsyncContextManager[RingBuffBytesReceiver]:
    '''
    Attach a receiver to the ringbuf described by an already opened
    `RBToken` and yield it wrapped as a `RingBuffBytesReceiver`.

    '''
    async with attach_to_ringbuf_receiver(
        token,
        cleanup=cleanup,
    ) as rb_receiver:
        rchan = RingBuffBytesReceiver(rb_receiver)
        yield rchan
@acm
async def attach_to_ringbuf_schannel(
    token: RBToken,
    cleanup: bool = True,
    batch_size: int = 1,
) -> AsyncContextManager[RingBuffBytesSender]:
    '''
    Attach a sender to the ringbuf described by an already opened
    `RBToken` and yield it wrapped as a `RingBuffBytesSender`
    configured with `batch_size`.

    '''
    async with attach_to_ringbuf_sender(
        token,
        cleanup=cleanup,
    ) as rb_sender:
        schan = RingBuffBytesSender(rb_sender, batch_size=batch_size)
        yield schan
class RingBuffChannel(trio.abc.Channel[bytes]):
    '''
    Combine `RingBuffBytesSender` and `RingBuffBytesReceiver`
    in order to expose the bidirectional `trio.abc.Channel` API.

    '''
    def __init__(
        self,
        sender: RingBuffBytesSender,
        receiver: RingBuffBytesReceiver
    ):
        self._sender = sender
        self._receiver = receiver

    async def send(self, value: bytes):
        '''
        Send one msg through the sender half.

        '''
        await self._sender.send(value)

    async def receive(self) -> bytes:
        '''
        Receive one msg from the receiver half.

        '''
        return await self._receiver.receive()

    async def aclose(self):
        '''
        Close both halves, receiver first; the sender is closed even
        if closing the receiver raises.

        '''
        try:
            await self._receiver.aclose()
        finally:
            await self._sender.aclose()
@acm
async def attach_to_ringbuf_channel(
    token_in: RBToken,
    token_out: RBToken,
    cleanup_in: bool = True,
    cleanup_out: bool = True
) -> AsyncContextManager[RingBuffChannel]:
    '''
    Attach both directions of an already opened ringbuf pair
    (`token_in` for receiving, `token_out` for sending) and yield
    them combined as a `RingBuffChannel`.

    '''
    async with attach_to_ringbuf_rchannel(
        token_in,
        cleanup=cleanup_in
    ) as receiver:
        async with attach_to_ringbuf_schannel(
            token_out,
            cleanup=cleanup_out
        ) as sender:
            chan = RingBuffChannel(sender, receiver)
            yield chan

View File

@ -18,389 +18,88 @@ TCP implementation of tractor.ipc._transport.MsgTransport protocol
''' '''
from __future__ import annotations from __future__ import annotations
from collections.abc import (
AsyncGenerator,
AsyncIterator,
)
import struct
from typing import (
Any,
Callable,
Type,
)
import msgspec
from tricycle import BufferedReceiveStream
import trio import trio
from tractor.msg import MsgCodec
from tractor.log import get_logger from tractor.log import get_logger
from tractor._exceptions import ( from tractor._addr import TCPAddress
MsgTypeError, from tractor.ipc._transport import MsgpackTransport
TransportClosed,
_mk_send_mte,
_mk_recv_mte,
)
from tractor.msg import (
_ctxvar_MsgCodec,
# _codec, XXX see `self._codec` sanity/debug checks
MsgCodec,
types as msgtypes,
pretty_struct,
)
from tractor.ipc import MsgTransport
log = get_logger(__name__) log = get_logger(__name__)
def get_stream_addrs(
stream: trio.SocketStream
) -> tuple[
tuple[str, int], # local
tuple[str, int], # remote
]:
'''
Return the `trio` streaming transport prot's socket-addrs for
both the local and remote sides as a pair.
'''
# rn, should both be IP sockets
lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername()
return (
tuple(lsockname[:2]),
tuple(rsockname[:2]),
)
# TODO: typing oddity.. not sure why we have to inherit here, but it # TODO: typing oddity.. not sure why we have to inherit here, but it
# seems to be an issue with `get_msg_transport()` returning # seems to be an issue with `get_msg_transport()` returning
# a `Type[Protocol]`; probably should make a `mypy` issue? # a `Type[Protocol]`; probably should make a `mypy` issue?
class MsgpackTCPStream(MsgTransport): class MsgpackTCPStream(MsgpackTransport):
''' '''
A ``trio.SocketStream`` delivering ``msgpack`` formatted data A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using the ``msgspec`` codec lib. using the ``msgspec`` codec lib.
''' '''
address_type = TCPAddress
layer_key: int = 4 layer_key: int = 4
name_key: str = 'tcp'
# TODO: better naming for this? # def __init__(
# -[ ] check how libp2p does naming for such things? # self,
codec_key: str = 'msgpack' # stream: trio.SocketStream,
# prefix_size: int = 4,
# codec: CodecType = None,
def __init__( # ) -> None:
self, # super().__init__(
stream: trio.SocketStream, # stream,
prefix_size: int = 4, # prefix_size=prefix_size,
# codec=codec
# XXX optionally provided codec pair for `msgspec`:
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
#
# TODO: define this as a `Codec` struct which can be
# overriden dynamically by the application/runtime?
codec: tuple[
Callable[[Any], Any]|None, # coder
Callable[[type, Any], Any]|None, # decoder
]|None = None,
) -> None:
self.stream = stream
assert self.stream.socket
# should both be IP sockets
self._laddr, self._raddr = get_stream_addrs(stream)
# create read loop instance
self._aiter_pkts = self._iter_packets()
self._send_lock = trio.StrictFIFOLock()
# public i guess?
self.drained: list[dict] = []
self.recv_stream = BufferedReceiveStream(
transport_stream=stream
)
self.prefix_size = prefix_size
# allow for custom IPC msg interchange format
# dynamic override Bo
self._task = trio.lowlevel.current_task()
# XXX for ctxvar debug only!
# self._codec: MsgCodec = (
# codec
# or
# _codec._ctxvar_MsgCodec.get()
# ) # )
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''
Yield `bytes`-blob decoded packets from the underlying TCP
stream using the current task's `MsgCodec`.
This is a streaming routine implemented as an async generator
func (which was the original design, but could be changed?)
and is allocated by a `.__call__()` inside `.__init__()` where
it is assigned to the `._aiter_pkts` attr.
'''
decodes_failed: int = 0
while True:
try:
header: bytes = await self.recv_stream.receive_exactly(4)
except (
ValueError,
ConnectionResetError,
# not sure entirely why we need this but without it we
# seem to be getting racy failures here on
# arbiter/registry name subs..
trio.BrokenResourceError,
) as trans_err:
loglevel = 'transport'
match trans_err:
# case (
# ConnectionResetError()
# ):
# loglevel = 'transport'
# peer actor (graceful??) TCP EOF but `tricycle`
# seems to raise a 0-bytes-read?
case ValueError() if (
'unclean EOF' in trans_err.args[0]
):
pass
# peer actor (task) prolly shutdown quickly due
# to cancellation
case trio.BrokenResourceError() if (
'Connection reset by peer' in trans_err.args[0]
):
pass
# unless the disconnect condition falls under "a
# normal operation breakage" we usualy console warn
# about it.
case _:
loglevel: str = 'warning'
raise TransportClosed(
message=(
f'IPC transport already closed by peer\n'
f'x)> {type(trans_err)}\n'
f' |_{self}\n'
),
loglevel=loglevel,
) from trans_err
# XXX definitely can happen if transport is closed
# manually by another `trio.lowlevel.Task` in the
# same actor; we use this in some simulated fault
# testing for ex, but generally should never happen
# under normal operation!
#
# NOTE: as such we always re-raise this error from the
# RPC msg loop!
except trio.ClosedResourceError as closure_err:
raise TransportClosed(
message=(
f'IPC transport already manually closed locally?\n'
f'x)> {type(closure_err)} \n'
f' |_{self}\n'
),
loglevel='error',
raise_on_report=(
closure_err.args[0] == 'another task closed this fd'
or
closure_err.args[0] in ['another task closed this fd']
),
) from closure_err
# graceful TCP EOF disconnect
if header == b'':
raise TransportClosed(
message=(
f'IPC transport already gracefully closed\n'
f')>\n'
f'|_{self}\n'
),
loglevel='transport',
# cause=??? # handy or no?
)
size: int
size, = struct.unpack("<I", header)
log.transport(f'received header {size}') # type: ignore
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
log.transport(f"received {msg_bytes}") # type: ignore
try:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# assert (
# task := trio.lowlevel.current_task()
# ) is not self._task
# self._task = task
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.recv()\n'
# f'codec: {self._codec}\n\n'
# f'msg_bytes: {msg_bytes}\n'
# )
yield codec.decode(msg_bytes)
# XXX NOTE: since the below error derives from
# `DecodeError` we need to catch is specially
# and always raise such that spec violations
# are never allowed to be caught silently!
except msgspec.ValidationError as verr:
msgtyperr: MsgTypeError = _mk_recv_mte(
msg=msg_bytes,
codec=codec,
src_validation_error=verr,
)
# XXX deliver up to `Channel.recv()` where
# a re-raise and `Error`-pack can inject the far
# end actor `.uid`.
yield msgtyperr
except (
msgspec.DecodeError,
UnicodeDecodeError,
):
if decodes_failed < 4:
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up.
try:
msg_str: str|bytes = msg_bytes.decode()
except UnicodeDecodeError:
msg_str = msg_bytes
log.exception(
'Failed to decode msg?\n'
f'{codec}\n\n'
'Rxed bytes from wire:\n\n'
f'{msg_str!r}\n'
)
decodes_failed += 1
else:
raise
async def send(
self,
msg: msgtypes.MsgType,
strict_types: bool = True,
hide_tb: bool = False,
) -> None:
'''
Send a msgpack encoded py-object-blob-as-msg over TCP.
If `strict_types == True` then a `MsgTypeError` will be raised on any
invalid msg type
'''
__tracebackhide__: bool = hide_tb
# XXX see `trio._sync.AsyncContextManagerMixin` for details
# on the `.acquire()`/`.release()` sequencing..
async with self._send_lock:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.send()\n'
# f'codec: {self._codec}\n\n'
# f'msg: {msg}\n'
# )
if type(msg) not in msgtypes.__msg_types__:
if strict_types:
raise _mk_send_mte(
msg,
codec=codec,
)
else:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
try:
bytes_data: bytes = codec.encode(msg)
except TypeError as _err:
typerr = _err
msgtyperr: MsgTypeError = _mk_send_mte(
msg,
codec=codec,
message=(
f'IPC-msg-spec violation in\n\n'
f'{pretty_struct.Struct.pformat(msg)}'
),
src_type_error=typerr,
)
raise msgtyperr from typerr
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: bytes = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
# ?TODO? does it help ever to dynamically show this
# frame?
# try:
# <the-above_code>
# except BaseException as _err:
# err = _err
# if not isinstance(err, MsgTypeError):
# __tracebackhide__: bool = False
# raise
@property @property
def laddr(self) -> tuple[str, int]: def maddr(self) -> str:
return self._laddr host, port = self.raddr.unwrap()
return (
f'/ipv4/{host}'
f'/{self.address_type.name_key}/{port}'
# f'/{self.chan.uid[0]}'
# f'/{self.cid}'
@property # f'/cid={cid_head}..{cid_tail}'
def raddr(self) -> tuple[str, int]: # TODO: ? not use this ^ right ?
return self._raddr )
async def recv(self) -> Any:
return await self._aiter_pkts.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._aiter_pkts
def connected(self) -> bool: def connected(self) -> bool:
return self.stream.socket.fileno() != -1 return self.stream.socket.fileno() != -1
@classmethod
async def connect_to(
    cls,
    destaddr: TCPAddress,
    prefix_size: int = 4,
    codec: MsgCodec|None = None,
    **kwargs
) -> MsgpackTCPStream:
    '''
    Open a TCP connection to `destaddr` and wrap the resulting
    stream in a new transport instance.

    `**kwargs` are forwarded verbatim to `trio.open_tcp_stream()`;
    `prefix_size` and `codec` go to the transport constructor.

    '''
    stream = await trio.open_tcp_stream(
        *destaddr.unwrap(),
        **kwargs
    )
    # build via `cls` so subclasses get an instance of their own type
    return cls(
        stream,
        prefix_size=prefix_size,
        codec=codec
    )
@classmethod
def get_stream_addrs(
    cls,
    stream: trio.SocketStream
) -> tuple[
    TCPAddress,  # local
    TCPAddress,  # remote
]:
    '''
    Return the `(local, remote)` socket addresses of `stream`, each
    wrapped as a `TCPAddress`.

    '''
    lsockname = stream.socket.getsockname()
    rsockname = stream.socket.getpeername()
    # only the first two items (host, port) of each sockname are kept
    return (
        TCPAddress.from_addr(tuple(lsockname[:2])),
        TCPAddress.from_addr(tuple(rsockname[:2])),
    )

View File

@ -18,13 +18,45 @@ typing.Protocol based generic msg API, implement this class to add backends for
tractor.ipc.Channel tractor.ipc.Channel
''' '''
import trio from __future__ import annotations
from typing import ( from typing import (
runtime_checkable, runtime_checkable,
Type,
Protocol, Protocol,
TypeVar, TypeVar,
ClassVar
) )
from collections.abc import AsyncIterator from collections.abc import (
AsyncGenerator,
AsyncIterator,
)
import struct
import trio
import msgspec
from tricycle import BufferedReceiveStream
from tractor.log import get_logger
from tractor._exceptions import (
MsgTypeError,
TransportClosed,
_mk_send_mte,
_mk_recv_mte,
)
from tractor.msg import (
_ctxvar_MsgCodec,
# _codec, XXX see `self._codec` sanity/debug checks
MsgCodec,
types as msgtypes,
pretty_struct,
)
from tractor._addr import Address
log = get_logger(__name__)
# (codec, transport)
MsgTransportKey = tuple[str, str]
# from tractor.msg.types import MsgType # from tractor.msg.types import MsgType
@ -41,11 +73,11 @@ class MsgTransport(Protocol[MsgType]):
# eventual msg definition/types? # eventual msg definition/types?
# - https://docs.python.org/3/library/typing.html#typing.Protocol # - https://docs.python.org/3/library/typing.html#typing.Protocol
stream: trio.SocketStream stream: trio.abc.Stream
drained: list[MsgType] drained: list[MsgType]
def __init__(self, stream: trio.SocketStream) -> None: address_type: ClassVar[Type[Address]]
... codec_key: ClassVar[str]
# XXX: should this instead be called `.sendall()`? # XXX: should this instead be called `.sendall()`?
async def send(self, msg: MsgType) -> None: async def send(self, msg: MsgType) -> None:
@ -65,10 +97,354 @@ class MsgTransport(Protocol[MsgType]):
def drain(self) -> AsyncIterator[dict]: def drain(self) -> AsyncIterator[dict]:
... ...
@classmethod
def key(cls) -> MsgTransportKey:
return cls.codec_key, cls.address_type.name_key
@property @property
def laddr(self) -> tuple[str, int]: def laddr(self) -> Address:
... ...
@property @property
def raddr(self) -> tuple[str, int]: def raddr(self) -> Address:
... ...
@property
def maddr(self) -> str:
...
@classmethod
async def connect_to(
cls,
addr: Address,
**kwargs
) -> MsgTransport:
...
@classmethod
def get_stream_addrs(
cls,
stream: trio.abc.Stream
) -> tuple[
Address, # local
Address # remote
]:
'''
Return the `trio` streaming transport prot's addrs for both
the local and remote sides as a pair.
'''
...
class MsgpackTransport(MsgTransport):
# TODO: better naming for this?
# -[ ] check how libp2p does naming for such things?
codec_key: str = 'msgpack'
def __init__(
self,
stream: trio.abc.Stream,
prefix_size: int = 4,
# XXX optionally provided codec pair for `msgspec`:
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
#
# TODO: define this as a `Codec` struct which can be
# overriden dynamically by the application/runtime?
codec: MsgCodec = None,
) -> None:
self.stream = stream
self._laddr, self._raddr = self.get_stream_addrs(stream)
# create read loop instance
self._aiter_pkts = self._iter_packets()
self._send_lock = trio.StrictFIFOLock()
# public i guess?
self.drained: list[dict] = []
self.recv_stream = BufferedReceiveStream(
transport_stream=stream
)
self.prefix_size = prefix_size
# allow for custom IPC msg interchange format
# dynamic override Bo
self._task = trio.lowlevel.current_task()
# XXX for ctxvar debug only!
# self._codec: MsgCodec = (
# codec
# or
# _codec._ctxvar_MsgCodec.get()
# )
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''
Yield `bytes`-blob decoded packets from the underlying TCP
stream using the current task's `MsgCodec`.
This is a streaming routine implemented as an async generator
func (which was the original design, but could be changed?)
and is allocated by a `.__call__()` inside `.__init__()` where
it is assigned to the `._aiter_pkts` attr.
'''
decodes_failed: int = 0
while True:
try:
header: bytes = await self.recv_stream.receive_exactly(4)
except (
ValueError,
ConnectionResetError,
# not sure entirely why we need this but without it we
# seem to be getting racy failures here on
# arbiter/registry name subs..
trio.BrokenResourceError,
) as trans_err:
loglevel = 'transport'
match trans_err:
# case (
# ConnectionResetError()
# ):
# loglevel = 'transport'
# peer actor (graceful??) TCP EOF but `tricycle`
# seems to raise a 0-bytes-read?
case ValueError() if (
'unclean EOF' in trans_err.args[0]
):
pass
# peer actor (task) prolly shutdown quickly due
# to cancellation
case trio.BrokenResourceError() if (
'Connection reset by peer' in trans_err.args[0]
):
pass
# unless the disconnect condition falls under "a
# normal operation breakage" we usualy console warn
# about it.
case _:
loglevel: str = 'warning'
raise TransportClosed(
message=(
f'IPC transport already closed by peer\n'
f'x)> {type(trans_err)}\n'
f' |_{self}\n'
),
loglevel=loglevel,
) from trans_err
# XXX definitely can happen if transport is closed
# manually by another `trio.lowlevel.Task` in the
# same actor; we use this in some simulated fault
# testing for ex, but generally should never happen
# under normal operation!
#
# NOTE: as such we always re-raise this error from the
# RPC msg loop!
except trio.ClosedResourceError as closure_err:
raise TransportClosed(
message=(
f'IPC transport already manually closed locally?\n'
f'x)> {type(closure_err)} \n'
f' |_{self}\n'
),
loglevel='error',
raise_on_report=(
closure_err.args[0] == 'another task closed this fd'
or
closure_err.args[0] in ['another task closed this fd']
),
) from closure_err
# graceful TCP EOF disconnect
if header == b'':
raise TransportClosed(
message=(
f'IPC transport already gracefully closed\n'
f')>\n'
f'|_{self}\n'
),
loglevel='transport',
# cause=??? # handy or no?
)
size: int
size, = struct.unpack("<I", header)
log.transport(f'received header {size}') # type: ignore
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
log.transport(f"received {msg_bytes}") # type: ignore
try:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# assert (
# task := trio.lowlevel.current_task()
# ) is not self._task
# self._task = task
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.recv()\n'
# f'codec: {self._codec}\n\n'
# f'msg_bytes: {msg_bytes}\n'
# )
yield codec.decode(msg_bytes)
# XXX NOTE: since the below error derives from
# `DecodeError` we need to catch is specially
# and always raise such that spec violations
# are never allowed to be caught silently!
except msgspec.ValidationError as verr:
msgtyperr: MsgTypeError = _mk_recv_mte(
msg=msg_bytes,
codec=codec,
src_validation_error=verr,
)
# XXX deliver up to `Channel.recv()` where
# a re-raise and `Error`-pack can inject the far
# end actor `.uid`.
yield msgtyperr
except (
msgspec.DecodeError,
UnicodeDecodeError,
):
if decodes_failed < 4:
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up.
try:
msg_str: str|bytes = msg_bytes.decode()
except UnicodeDecodeError:
msg_str = msg_bytes
log.exception(
'Failed to decode msg?\n'
f'{codec}\n\n'
'Rxed bytes from wire:\n\n'
f'{msg_str!r}\n'
)
decodes_failed += 1
else:
raise
async def send(
self,
msg: msgtypes.MsgType,
strict_types: bool = True,
hide_tb: bool = False,
) -> None:
'''
Send a msgpack encoded py-object-blob-as-msg over TCP.
If `strict_types == True` then a `MsgTypeError` will be raised on any
invalid msg type
'''
__tracebackhide__: bool = hide_tb
# XXX see `trio._sync.AsyncContextManagerMixin` for details
# on the `.acquire()`/`.release()` sequencing..
async with self._send_lock:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.send()\n'
# f'codec: {self._codec}\n\n'
# f'msg: {msg}\n'
# )
if type(msg) not in msgtypes.__msg_types__:
if strict_types:
raise _mk_send_mte(
msg,
codec=codec,
)
else:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
try:
bytes_data: bytes = codec.encode(msg)
except TypeError as _err:
typerr = _err
msgtyperr: MsgTypeError = _mk_send_mte(
msg,
codec=codec,
message=(
f'IPC-msg-spec violation in\n\n'
f'{pretty_struct.Struct.pformat(msg)}'
),
src_type_error=typerr,
)
raise msgtyperr from typerr
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: bytes = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
# ?TODO? does it help ever to dynamically show this
# frame?
# try:
# <the-above_code>
# except BaseException as _err:
# err = _err
# if not isinstance(err, MsgTypeError):
# __tracebackhide__: bool = False
# raise
async def recv(self) -> msgtypes.MsgType:
return await self._aiter_pkts.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._aiter_pkts
@property
def laddr(self) -> Address:
return self._laddr
@property
def raddr(self) -> Address:
return self._raddr

View File

@ -0,0 +1,99 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type
import trio
import socket
from tractor._addr import Address
from tractor.ipc._transport import (
MsgTransportKey,
MsgTransport
)
from tractor.ipc._tcp import MsgpackTCPStream
from tractor.ipc._uds import MsgpackUDSStream
# Hand-maintained registry of every supported msg transport type;
# both lookup tables below are derived from it at import time.
_msg_transports = [MsgpackTCPStream, MsgpackUDSStream]

# `MsgTransportKey` (as returned by each type's `.key()`) -> transport type
_key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = dict(
    (transport_type.key(), transport_type)
    for transport_type in _msg_transports
)

# `Address` wrapper type (each type's `.address_type`) -> transport type
_addr_to_transport: dict[Type[Address], Type[MsgTransport]] = dict(
    (transport_type.address_type, transport_type)
    for transport_type in _msg_transports
)
def transport_from_addr(
    addr: Address,
    codec_key: str = 'msgpack',
) -> Type[MsgTransport]:
    '''
    Given a destination address and a desired codec, find the
    corresponding `MsgTransport` type.

    Raises `NotImplementedError` when no transport is registered
    for `type(addr)`, or when the registered transport's
    `.codec_key` differs from the requested `codec_key`.

    '''
    try:
        transport = _addr_to_transport[type(addr)]
    except KeyError:
        raise NotImplementedError(
            f'No known transport for address {repr(addr)}'
        )

    # the address table holds a single transport per address type,
    # so verify the match really speaks the requested codec rather
    # than silently handing back one that uses a different codec.
    if transport.codec_key != codec_key:
        raise NotImplementedError(
            f'No known transport for address {repr(addr)} '
            f'using codec {codec_key!r}'
        )

    return transport
def transport_from_stream(
    stream: trio.abc.Stream,
    codec_key: str = 'msgpack'
) -> Type[MsgTransport]:
    '''
    Resolve the `MsgTransport` type matching an already opened
    `trio.abc.Stream` and the requested codec.

    Only `trio.SocketStream` instances are recognised: the socket
    family selects the transport name (`'tcp'` for `AF_INET` and
    `AF_INET6`, `'uds'` for `AF_UNIX`) which is then paired with
    `codec_key` for the `_key_to_transport` lookup.

    Raises `NotImplementedError` for any other stream type or
    socket family.

    '''
    # anything which isn't socket-backed can't be classified
    if not isinstance(stream, trio.SocketStream):
        raise NotImplementedError(
            f'Could not figure out transport type for stream type {type(stream)}'
        )

    family = stream.socket.family
    if family in (socket.AF_INET, socket.AF_INET6):
        transport_name = 'tcp'
    elif family == socket.AF_UNIX:
        transport_name = 'uds'
    else:
        raise NotImplementedError(
            f'Unsupported socket family: {family}'
        )

    return _key_to_transport[(codec_key, transport_name)]

View File

@ -0,0 +1,97 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protocol
'''
from __future__ import annotations
import trio
from tractor.msg import MsgCodec
from tractor.log import get_logger
from tractor._addr import UDSAddress
from tractor.ipc._transport import MsgpackTransport
log = get_logger(__name__)
class MsgpackUDSStream(MsgpackTransport):
    '''
    A ``trio.SocketStream`` over a Unix domain socket delivering
    ``msgpack`` formatted data using the ``msgspec`` codec lib.

    '''
    address_type = UDSAddress
    layer_key: int = 7

    # def __init__(
    #     self,
    #     stream: trio.SocketStream,
    #     prefix_size: int = 4,
    #     codec: CodecType = None,

    # ) -> None:
    #     super().__init__(
    #         stream,
    #         prefix_size=prefix_size,
    #         codec=codec
    #     )

    @property
    def maddr(self) -> str:
        '''
        A multiaddr-style string built from the address type's
        `name_key` and the unwrapped remote address.

        '''
        filepath = self.raddr.unwrap()
        return (
            f'/ipv4/localhost'
            f'/{self.address_type.name_key}/{filepath}'
            # f'/{self.chan.uid[0]}'
            # f'/{self.cid}'

            # f'/cid={cid_head}..{cid_tail}'
            # TODO: ? not use this ^ right ?
        )

    def connected(self) -> bool:
        '''
        `True` as long as the underlying socket's `.fileno()` is
        not `-1`.

        '''
        return self.stream.socket.fileno() != -1

    @classmethod
    async def connect_to(
        cls,
        addr: UDSAddress,
        prefix_size: int = 4,
        codec: MsgCodec|None = None,
        **kwargs
    ) -> MsgpackUDSStream:
        '''
        Open a Unix domain socket connection to `addr` and wrap the
        resulting stream in this transport type.

        Any extra `kwargs` are forwarded to
        `trio.open_unix_socket()` untouched.

        '''
        stream = await trio.open_unix_socket(
            addr.unwrap(),
            **kwargs
        )
        # instantiate through `cls` so a subclass gets back an
        # instance of its own type.
        return cls(
            stream,
            prefix_size=prefix_size,
            codec=codec
        )

    @classmethod
    def get_stream_addrs(
        cls,
        stream: trio.SocketStream
    ) -> tuple[UDSAddress, UDSAddress]:
        '''
        Return the `(local, remote)` addresses of `stream` as
        a pair of `UDSAddress` wrappers.

        '''
        sock = stream.socket
        return (
            # local end
            UDSAddress.from_addr(sock.getsockname()),
            # remote end: must come from the *peer* name, querying
            # the sockname twice would report the local address for
            # both sides.
            UDSAddress.from_addr(sock.getpeername()),
        )

View File

@ -47,6 +47,7 @@ from tractor.msg import (
pretty_struct, pretty_struct,
) )
from tractor.log import get_logger from tractor.log import get_logger
from tractor._addr import AddressTypes
log = get_logger('tractor.msgspec') log = get_logger('tractor.msgspec')
@ -167,8 +168,8 @@ class SpawnSpec(
# TODO: not just sockaddr pairs? # TODO: not just sockaddr pairs?
# -[ ] abstract into a `TransportAddr` type? # -[ ] abstract into a `TransportAddr` type?
reg_addrs: list[tuple[str, int]] reg_addrs: list[AddressTypes]
bind_addrs: list[tuple[str, int]] bind_addrs: list[AddressTypes]
# TODO: caps based RPC support in the payload? # TODO: caps based RPC support in the payload?