Compare commits
11 Commits
1762b3eb64
...
efd11f7d74
Author | SHA1 | Date |
---|---|---|
|
efd11f7d74 | |
|
76cee99fc2 | |
|
5f50206d84 | |
|
a47a7a39b1 | |
|
bab265b2d8 | |
|
010874bed5 | |
|
ea010ab46a | |
|
be7fc89ae9 | |
|
2a9a78651b | |
|
be818a720a | |
|
ba353bf46f |
|
@ -10,9 +10,10 @@ pkgs.mkShell {
|
|||
inherit nativeBuildInputs;
|
||||
|
||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
|
||||
TMPDIR = "/tmp";
|
||||
|
||||
shellHook = ''
|
||||
set -e
|
||||
uv venv .venv --python=3.12
|
||||
uv venv .venv --python=3.11
|
||||
'';
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ async def main(service_name):
|
|||
async with tractor.open_nursery() as an:
|
||||
await an.start_actor(service_name)
|
||||
|
||||
async with tractor.get_registry('127.0.0.1', 1616) as portal:
|
||||
async with tractor.get_registry() as portal:
|
||||
print(f"Arbiter is listening on {portal.channel}")
|
||||
|
||||
async with tractor.wait_for_actor(service_name) as sockaddr:
|
||||
|
|
|
@ -26,7 +26,7 @@ async def test_reg_then_unreg(reg_addr):
|
|||
portal = await n.start_actor('actor', enable_modules=[__name__])
|
||||
uid = portal.channel.uid
|
||||
|
||||
async with tractor.get_registry(*reg_addr) as aportal:
|
||||
async with tractor.get_registry(reg_addr) as aportal:
|
||||
# this local actor should be the arbiter
|
||||
assert actor is aportal.actor
|
||||
|
||||
|
@ -160,7 +160,7 @@ async def spawn_and_check_registry(
|
|||
async with tractor.open_root_actor(
|
||||
registry_addrs=[reg_addr],
|
||||
):
|
||||
async with tractor.get_registry(*reg_addr) as portal:
|
||||
async with tractor.get_registry(reg_addr) as portal:
|
||||
# runtime needs to be up to call this
|
||||
actor = tractor.current_actor()
|
||||
|
||||
|
@ -300,7 +300,7 @@ async def close_chans_before_nursery(
|
|||
async with tractor.open_root_actor(
|
||||
registry_addrs=[reg_addr],
|
||||
):
|
||||
async with tractor.get_registry(*reg_addr) as aportal:
|
||||
async with tractor.get_registry(reg_addr) as aportal:
|
||||
try:
|
||||
get_reg = partial(unpack_reg, aportal)
|
||||
|
||||
|
|
|
@ -66,6 +66,9 @@ def run_example_in_subproc(
|
|||
# due to backpressure!!!
|
||||
proc = testdir.popen(
|
||||
cmdargs,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
**kwargs,
|
||||
)
|
||||
assert not proc.returncode
|
||||
|
@ -119,10 +122,14 @@ def test_example(
|
|||
code = ex.read()
|
||||
|
||||
with run_example_in_subproc(code) as proc:
|
||||
proc.wait()
|
||||
err, _ = proc.stderr.read(), proc.stdout.read()
|
||||
# print(f'STDERR: {err}')
|
||||
# print(f'STDOUT: {out}')
|
||||
err = None
|
||||
try:
|
||||
if not proc.poll():
|
||||
_, err = proc.communicate(timeout=15)
|
||||
|
||||
except subprocess.TimeoutExpired as e:
|
||||
proc.kill()
|
||||
err = e.stderr
|
||||
|
||||
# if we get some gnarly output let's aggregate and raise
|
||||
if err:
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
import trio
|
||||
import pytest
|
||||
from tractor.ipc import (
|
||||
open_eventfd,
|
||||
EFDReadCancelled,
|
||||
EventFD
|
||||
)
|
||||
|
||||
|
||||
def test_eventfd_read_cancellation():
|
||||
'''
|
||||
Ensure EventFD.read raises EFDReadCancelled if EventFD.close()
|
||||
is called.
|
||||
|
||||
'''
|
||||
fd = open_eventfd()
|
||||
|
||||
async def _read(event: EventFD):
|
||||
with pytest.raises(EFDReadCancelled):
|
||||
await event.read()
|
||||
|
||||
async def main():
|
||||
async with trio.open_nursery() as n:
|
||||
with (
|
||||
EventFD(fd, 'w') as event,
|
||||
trio.fail_after(3)
|
||||
):
|
||||
n.start_soon(_read, event)
|
||||
await trio.sleep(0.2)
|
||||
event.close()
|
||||
|
||||
trio.run(main)
|
|
@ -871,7 +871,7 @@ async def serve_subactors(
|
|||
)
|
||||
await ipc.send((
|
||||
peer.chan.uid,
|
||||
peer.chan.raddr,
|
||||
peer.chan.raddr.unwrap(),
|
||||
))
|
||||
|
||||
print('Spawner exiting spawn serve loop!')
|
||||
|
|
|
@ -38,7 +38,7 @@ async def test_self_is_registered_localportal(reg_addr):
|
|||
"Verify waiting on the arbiter to register itself using a local portal."
|
||||
actor = tractor.current_actor()
|
||||
assert actor.is_arbiter
|
||||
async with tractor.get_registry(*reg_addr) as portal:
|
||||
async with tractor.get_registry(reg_addr) as portal:
|
||||
assert isinstance(portal, tractor._portal.LocalPortal)
|
||||
|
||||
with trio.fail_after(0.2):
|
||||
|
|
|
@ -32,7 +32,7 @@ def test_abort_on_sigint(daemon):
|
|||
@tractor_test
|
||||
async def test_cancel_remote_arbiter(daemon, reg_addr):
|
||||
assert not tractor.current_actor().is_arbiter
|
||||
async with tractor.get_registry(*reg_addr) as portal:
|
||||
async with tractor.get_registry(reg_addr) as portal:
|
||||
await portal.cancel_actor()
|
||||
|
||||
time.sleep(0.1)
|
||||
|
@ -41,7 +41,7 @@ async def test_cancel_remote_arbiter(daemon, reg_addr):
|
|||
|
||||
# no arbiter socket should exist
|
||||
with pytest.raises(OSError):
|
||||
async with tractor.get_registry(*reg_addr) as portal:
|
||||
async with tractor.get_registry(reg_addr) as portal:
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -1,15 +1,21 @@
|
|||
import time
|
||||
import hashlib
|
||||
|
||||
import trio
|
||||
import pytest
|
||||
import tractor
|
||||
from tractor.ipc import (
|
||||
open_ringbuf,
|
||||
attach_to_ringbuf_receiver,
|
||||
attach_to_ringbuf_sender,
|
||||
attach_to_ringbuf_stream,
|
||||
attach_to_ringbuf_channel,
|
||||
RBToken,
|
||||
RingBuffSender,
|
||||
RingBuffReceiver
|
||||
)
|
||||
from tractor._testing.samples import generate_sample_messages
|
||||
from tractor._testing.samples import (
|
||||
generate_single_byte_msgs,
|
||||
generate_sample_messages
|
||||
)
|
||||
|
||||
|
||||
@tractor.context
|
||||
|
@ -17,19 +23,27 @@ async def child_read_shm(
|
|||
ctx: tractor.Context,
|
||||
msg_amount: int,
|
||||
token: RBToken,
|
||||
total_bytes: int,
|
||||
) -> None:
|
||||
recvd_bytes = 0
|
||||
await ctx.started()
|
||||
start_ts = time.time()
|
||||
async with RingBuffReceiver(token) as receiver:
|
||||
while recvd_bytes < total_bytes:
|
||||
msg = await receiver.receive_some()
|
||||
recvd_bytes += len(msg)
|
||||
) -> str:
|
||||
'''
|
||||
Sub-actor used in `test_ringbuf`.
|
||||
|
||||
# make sure we dont hold any memoryviews
|
||||
# before the ctx manager aclose()
|
||||
msg = None
|
||||
Attach to a ringbuf and receive all messages until end of stream.
|
||||
Keep track of how many bytes received and also calculate
|
||||
sha256 of the whole byte stream.
|
||||
|
||||
Calculate and print performance stats, finally return calculated
|
||||
hash.
|
||||
|
||||
'''
|
||||
await ctx.started()
|
||||
print('reader started')
|
||||
recvd_bytes = 0
|
||||
recvd_hash = hashlib.sha256()
|
||||
start_ts = time.time()
|
||||
async with attach_to_ringbuf_receiver(token) as receiver:
|
||||
async for msg in receiver:
|
||||
recvd_hash.update(msg)
|
||||
recvd_bytes += len(msg)
|
||||
|
||||
end_ts = time.time()
|
||||
elapsed = end_ts - start_ts
|
||||
|
@ -38,6 +52,9 @@ async def child_read_shm(
|
|||
print(f'\n\telapsed ms: {elapsed_ms}')
|
||||
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
||||
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
||||
print(f'\treceived bytes: {recvd_bytes:,}')
|
||||
|
||||
return recvd_hash.hexdigest()
|
||||
|
||||
|
||||
@tractor.context
|
||||
|
@ -48,16 +65,32 @@ async def child_write_shm(
|
|||
rand_max: int,
|
||||
token: RBToken,
|
||||
) -> None:
|
||||
msgs, total_bytes = generate_sample_messages(
|
||||
'''
|
||||
Sub-actor used in `test_ringbuf`
|
||||
|
||||
Generate `msg_amount` payloads with
|
||||
`random.randint(rand_min, rand_max)` random bytes at the end,
|
||||
Calculate sha256 hash and send it to parent on `ctx.started`.
|
||||
|
||||
Attach to ringbuf and send all generated messages.
|
||||
|
||||
'''
|
||||
msgs, _total_bytes = generate_sample_messages(
|
||||
msg_amount,
|
||||
rand_min=rand_min,
|
||||
rand_max=rand_max,
|
||||
)
|
||||
await ctx.started(total_bytes)
|
||||
async with RingBuffSender(token) as sender:
|
||||
print('writer hashing payload...')
|
||||
sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest()
|
||||
print('writer done hashing.')
|
||||
await ctx.started(sent_hash)
|
||||
print('writer started')
|
||||
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
|
||||
for msg in msgs:
|
||||
await sender.send_all(msg)
|
||||
|
||||
print('writer exit')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'msg_amount,rand_min,rand_max,buf_size',
|
||||
|
@ -83,19 +116,23 @@ def test_ringbuf(
|
|||
rand_max: int,
|
||||
buf_size: int
|
||||
):
|
||||
'''
|
||||
- Open a new ring buf on root actor
|
||||
- Open `child_write_shm` ctx in sub-actor which will generate a
|
||||
random payload and send its hash on `ctx.started`, finally sending
|
||||
the payload through the stream.
|
||||
- Open `child_read_shm` ctx in sub-actor which will receive the
|
||||
payload, calculate perf stats and return the hash.
|
||||
- Compare both hashes
|
||||
|
||||
'''
|
||||
async def main():
|
||||
with open_ringbuf(
|
||||
'test_ringbuf',
|
||||
buf_size=buf_size
|
||||
) as token:
|
||||
proc_kwargs = {
|
||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
||||
}
|
||||
proc_kwargs = {'pass_fds': token.fds}
|
||||
|
||||
common_kwargs = {
|
||||
'msg_amount': msg_amount,
|
||||
'token': token,
|
||||
}
|
||||
async with tractor.open_nursery() as an:
|
||||
send_p = await an.start_actor(
|
||||
'ring_sender',
|
||||
|
@ -110,17 +147,20 @@ def test_ringbuf(
|
|||
async with (
|
||||
send_p.open_context(
|
||||
child_write_shm,
|
||||
token=token,
|
||||
msg_amount=msg_amount,
|
||||
rand_min=rand_min,
|
||||
rand_max=rand_max,
|
||||
**common_kwargs
|
||||
) as (sctx, total_bytes),
|
||||
) as (_sctx, sent_hash),
|
||||
recv_p.open_context(
|
||||
child_read_shm,
|
||||
**common_kwargs,
|
||||
total_bytes=total_bytes,
|
||||
) as (sctx, _sent),
|
||||
token=token,
|
||||
msg_amount=msg_amount
|
||||
) as (rctx, _sent),
|
||||
):
|
||||
await recv_p.result()
|
||||
recvd_hash = await rctx.result()
|
||||
|
||||
assert sent_hash == recvd_hash
|
||||
|
||||
await send_p.cancel_actor()
|
||||
await recv_p.cancel_actor()
|
||||
|
@ -134,23 +174,28 @@ async def child_blocked_receiver(
|
|||
ctx: tractor.Context,
|
||||
token: RBToken
|
||||
):
|
||||
async with RingBuffReceiver(token) as receiver:
|
||||
async with attach_to_ringbuf_receiver(token) as receiver:
|
||||
await ctx.started()
|
||||
await receiver.receive_some()
|
||||
|
||||
|
||||
def test_ring_reader_cancel():
|
||||
def test_reader_cancel():
|
||||
'''
|
||||
Test that a receiver blocked on eventfd(2) read responds to
|
||||
cancellation.
|
||||
|
||||
'''
|
||||
async def main():
|
||||
with open_ringbuf('test_ring_cancel_reader') as token:
|
||||
async with (
|
||||
tractor.open_nursery() as an,
|
||||
RingBuffSender(token) as _sender,
|
||||
attach_to_ringbuf_sender(token) as _sender,
|
||||
):
|
||||
recv_p = await an.start_actor(
|
||||
'ring_blocked_receiver',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs={
|
||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
||||
'pass_fds': token.fds
|
||||
}
|
||||
)
|
||||
async with (
|
||||
|
@ -172,12 +217,17 @@ async def child_blocked_sender(
|
|||
ctx: tractor.Context,
|
||||
token: RBToken
|
||||
):
|
||||
async with RingBuffSender(token) as sender:
|
||||
async with attach_to_ringbuf_sender(token) as sender:
|
||||
await ctx.started()
|
||||
await sender.send_all(b'this will wrap')
|
||||
|
||||
|
||||
def test_ring_sender_cancel():
|
||||
def test_sender_cancel():
|
||||
'''
|
||||
Test that a sender blocked on eventfd(2) read responds to
|
||||
cancellation.
|
||||
|
||||
'''
|
||||
async def main():
|
||||
with open_ringbuf(
|
||||
'test_ring_cancel_sender',
|
||||
|
@ -188,7 +238,7 @@ def test_ring_sender_cancel():
|
|||
'ring_blocked_sender',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs={
|
||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
||||
'pass_fds': token.fds
|
||||
}
|
||||
)
|
||||
async with (
|
||||
|
@ -203,3 +253,171 @@ def test_ring_sender_cancel():
|
|||
|
||||
with pytest.raises(tractor._exceptions.ContextCancelled):
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_receiver_max_bytes():
|
||||
'''
|
||||
Test that RingBuffReceiver.receive_some's max_bytes optional
|
||||
argument works correctly, send a msg of size 100, then
|
||||
force receive of messages with max_bytes == 1, wait until
|
||||
100 of these messages are received, then compare join of
|
||||
msgs with original message
|
||||
|
||||
'''
|
||||
msg = generate_single_byte_msgs(100)
|
||||
msgs = []
|
||||
|
||||
async def main():
|
||||
with open_ringbuf(
|
||||
'test_ringbuf_max_bytes',
|
||||
buf_size=10
|
||||
) as token:
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
attach_to_ringbuf_sender(token, cleanup=False) as sender,
|
||||
attach_to_ringbuf_receiver(token, cleanup=False) as receiver
|
||||
):
|
||||
async def _send_and_close():
|
||||
await sender.send_all(msg)
|
||||
await sender.aclose()
|
||||
|
||||
n.start_soon(_send_and_close)
|
||||
while len(msgs) < len(msg):
|
||||
msg_part = await receiver.receive_some(max_bytes=1)
|
||||
assert len(msg_part) == 1
|
||||
msgs.append(msg_part)
|
||||
|
||||
trio.run(main)
|
||||
assert msg == b''.join(msgs)
|
||||
|
||||
|
||||
def test_stapled_ringbuf():
|
||||
'''
|
||||
Open two ringbufs and give tokens to tasks (swap them such that in/out tokens
|
||||
are inversed on each task) which will open the streams and use trio.StapledStream
|
||||
to have a single bidirectional stream.
|
||||
|
||||
Then take turns to send and receive messages.
|
||||
|
||||
'''
|
||||
msg = generate_single_byte_msgs(100)
|
||||
pair_0_msgs = []
|
||||
pair_1_msgs = []
|
||||
|
||||
pair_0_done = trio.Event()
|
||||
pair_1_done = trio.Event()
|
||||
|
||||
async def pair_0(token_in: RBToken, token_out: RBToken):
|
||||
async with attach_to_ringbuf_stream(
|
||||
token_in,
|
||||
token_out,
|
||||
cleanup_in=False,
|
||||
cleanup_out=False
|
||||
) as stream:
|
||||
# first turn to send
|
||||
await stream.send_all(msg)
|
||||
|
||||
# second turn to receive
|
||||
while len(pair_0_msgs) != len(msg):
|
||||
_msg = await stream.receive_some(max_bytes=1)
|
||||
pair_0_msgs.append(_msg)
|
||||
|
||||
pair_0_done.set()
|
||||
await pair_1_done.wait()
|
||||
|
||||
|
||||
async def pair_1(token_in: RBToken, token_out: RBToken):
|
||||
async with attach_to_ringbuf_stream(
|
||||
token_in,
|
||||
token_out,
|
||||
cleanup_in=False,
|
||||
cleanup_out=False
|
||||
) as stream:
|
||||
# first turn to receive
|
||||
while len(pair_1_msgs) != len(msg):
|
||||
_msg = await stream.receive_some(max_bytes=1)
|
||||
pair_1_msgs.append(_msg)
|
||||
|
||||
# second turn to send
|
||||
await stream.send_all(msg)
|
||||
|
||||
pair_1_done.set()
|
||||
await pair_0_done.wait()
|
||||
|
||||
|
||||
async def main():
|
||||
with tractor.ipc.open_ringbuf_pair(
|
||||
'test_stapled_ringbuf'
|
||||
) as (token_0, token_1):
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(pair_0, token_0, token_1)
|
||||
n.start_soon(pair_1, token_1, token_0)
|
||||
|
||||
|
||||
trio.run(main)
|
||||
|
||||
assert msg == b''.join(pair_0_msgs)
|
||||
assert msg == b''.join(pair_1_msgs)
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def child_channel_sender(
|
||||
ctx: tractor.Context,
|
||||
msg_amount_min: int,
|
||||
msg_amount_max: int,
|
||||
token_in: RBToken,
|
||||
token_out: RBToken
|
||||
):
|
||||
import random
|
||||
msgs, _total_bytes = generate_sample_messages(
|
||||
random.randint(msg_amount_min, msg_amount_max),
|
||||
rand_min=256,
|
||||
rand_max=1024,
|
||||
)
|
||||
async with attach_to_ringbuf_channel(
|
||||
token_in,
|
||||
token_out
|
||||
) as chan:
|
||||
await ctx.started(msgs)
|
||||
|
||||
for msg in msgs:
|
||||
await chan.send(msg)
|
||||
|
||||
|
||||
def test_channel():
|
||||
|
||||
msg_amount_min = 100
|
||||
msg_amount_max = 1000
|
||||
|
||||
async def main():
|
||||
with tractor.ipc.open_ringbuf_pair(
|
||||
'test_ringbuf_transport'
|
||||
) as (token_0, token_1):
|
||||
async with (
|
||||
attach_to_ringbuf_channel(token_0, token_1) as chan,
|
||||
tractor.open_nursery() as an
|
||||
):
|
||||
recv_p = await an.start_actor(
|
||||
'test_ringbuf_transport_sender',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs={
|
||||
'pass_fds': token_0.fds + token_1.fds
|
||||
}
|
||||
)
|
||||
async with (
|
||||
recv_p.open_context(
|
||||
child_channel_sender,
|
||||
msg_amount_min=msg_amount_min,
|
||||
msg_amount_max=msg_amount_max,
|
||||
token_in=token_1,
|
||||
token_out=token_0
|
||||
) as (ctx, msgs),
|
||||
):
|
||||
recv_msgs = []
|
||||
async for msg in chan:
|
||||
recv_msgs.append(msg)
|
||||
|
||||
await recv_p.cancel_actor()
|
||||
assert recv_msgs == msgs
|
||||
|
||||
trio.run(main)
|
||||
|
|
|
@ -77,7 +77,7 @@ async def movie_theatre_question():
|
|||
async def test_movie_theatre_convo(start_method):
|
||||
"""The main ``tractor`` routine.
|
||||
"""
|
||||
async with tractor.open_nursery() as n:
|
||||
async with tractor.open_nursery(debug_mode=True) as n:
|
||||
|
||||
portal = await n.start_actor(
|
||||
'frank',
|
||||
|
|
|
@ -0,0 +1,310 @@
|
|||
# tractor: structured concurrent "actors".
|
||||
# Copyright 2018-eternity Tyler Goodlet.
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import tempfile
|
||||
from uuid import uuid4
|
||||
from typing import (
|
||||
Protocol,
|
||||
ClassVar,
|
||||
TypeVar,
|
||||
Union,
|
||||
Type
|
||||
)
|
||||
|
||||
import trio
|
||||
from trio import socket
|
||||
|
||||
|
||||
NamespaceType = TypeVar('NamespaceType')
|
||||
AddressType = TypeVar('AddressType')
|
||||
StreamType = TypeVar('StreamType')
|
||||
ListenerType = TypeVar('ListenerType')
|
||||
|
||||
|
||||
class Address(Protocol[
|
||||
NamespaceType,
|
||||
AddressType,
|
||||
StreamType,
|
||||
ListenerType
|
||||
]):
|
||||
|
||||
name_key: ClassVar[str]
|
||||
address_type: ClassVar[Type[AddressType]]
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
...
|
||||
|
||||
@property
|
||||
def namespace(self) -> NamespaceType|None:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def from_addr(cls, addr: AddressType) -> Address:
|
||||
...
|
||||
|
||||
def unwrap(self) -> AddressType:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_random(cls, namespace: NamespaceType | None = None) -> Address:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_root(cls) -> Address:
|
||||
...
|
||||
|
||||
def __repr__(self) -> str:
|
||||
...
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
...
|
||||
|
||||
async def open_stream(self, **kwargs) -> StreamType:
|
||||
...
|
||||
|
||||
async def open_listener(self, **kwargs) -> ListenerType:
|
||||
...
|
||||
|
||||
async def close_listener(self):
|
||||
...
|
||||
|
||||
|
||||
class TCPAddress(Address[
|
||||
str,
|
||||
tuple[str, int],
|
||||
trio.SocketStream,
|
||||
trio.SocketListener
|
||||
]):
|
||||
|
||||
name_key: str = 'tcp'
|
||||
address_type: type = tuple[str, int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int
|
||||
):
|
||||
if (
|
||||
not isinstance(host, str)
|
||||
or
|
||||
not isinstance(port, int)
|
||||
):
|
||||
raise TypeError(f'Expected host {host} to be str and port {port} to be int')
|
||||
self._host = host
|
||||
self._port = port
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return self._port != 0
|
||||
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return self._host
|
||||
|
||||
@classmethod
|
||||
def from_addr(cls, addr: tuple[str, int]) -> TCPAddress:
|
||||
return TCPAddress(addr[0], addr[1])
|
||||
|
||||
def unwrap(self) -> tuple[str, int]:
|
||||
return self._host, self._port
|
||||
|
||||
@classmethod
|
||||
def get_random(cls, namespace: str = '127.0.0.1') -> TCPAddress:
|
||||
return TCPAddress(namespace, 0)
|
||||
|
||||
@classmethod
|
||||
def get_root(cls) -> Address:
|
||||
return TCPAddress('127.0.0.1', 1616)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{type(self)} @ {self.unwrap()}'
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, TCPAddress):
|
||||
raise TypeError(
|
||||
f'Can not compare {type(other)} with {type(self)}'
|
||||
)
|
||||
|
||||
return (
|
||||
self._host == other._host
|
||||
and
|
||||
self._port == other._port
|
||||
)
|
||||
|
||||
async def open_stream(self, **kwargs) -> trio.SocketStream:
|
||||
stream = await trio.open_tcp_stream(
|
||||
self._host,
|
||||
self._port,
|
||||
**kwargs
|
||||
)
|
||||
self._host, self._port = stream.socket.getsockname()[:2]
|
||||
return stream
|
||||
|
||||
async def open_listener(self, **kwargs) -> trio.SocketListener:
|
||||
listeners = await trio.open_tcp_listeners(
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
**kwargs
|
||||
)
|
||||
assert len(listeners) == 1
|
||||
listener = listeners[0]
|
||||
self._host, self._port = listener.socket.getsockname()[:2]
|
||||
return listener
|
||||
|
||||
async def close_listener(self):
|
||||
...
|
||||
|
||||
|
||||
class UDSAddress(Address[
|
||||
None,
|
||||
str,
|
||||
trio.SocketStream,
|
||||
trio.SocketListener
|
||||
]):
|
||||
|
||||
name_key: str = 'uds'
|
||||
address_type: type = str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filepath: str
|
||||
):
|
||||
self._filepath = filepath
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def namespace(self) -> None:
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def from_addr(cls, filepath: str) -> UDSAddress:
|
||||
return UDSAddress(filepath)
|
||||
|
||||
def unwrap(self) -> str:
|
||||
return self._filepath
|
||||
|
||||
@classmethod
|
||||
def get_random(cls, namespace: None = None) -> UDSAddress:
|
||||
return UDSAddress(f'{tempfile.gettempdir()}/{uuid4()}.sock')
|
||||
|
||||
@classmethod
|
||||
def get_root(cls) -> Address:
|
||||
return UDSAddress('tractor.sock')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{type(self)} @ {self._filepath}'
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, UDSAddress):
|
||||
raise TypeError(
|
||||
f'Can not compare {type(other)} with {type(self)}'
|
||||
)
|
||||
|
||||
return self._filepath == other._filepath
|
||||
|
||||
async def open_stream(self, **kwargs) -> trio.SocketStream:
|
||||
stream = await trio.open_unix_socket(
|
||||
self._filepath,
|
||||
**kwargs
|
||||
)
|
||||
return stream
|
||||
|
||||
async def open_listener(self, **kwargs) -> trio.SocketListener:
|
||||
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
await self._sock.bind(self._filepath)
|
||||
self._sock.listen(1)
|
||||
return trio.SocketListener(self._sock)
|
||||
|
||||
async def close_listener(self):
|
||||
self._sock.close()
|
||||
os.unlink(self._filepath)
|
||||
|
||||
|
||||
preferred_transport = 'uds'
|
||||
|
||||
|
||||
_address_types = (
|
||||
TCPAddress,
|
||||
UDSAddress
|
||||
)
|
||||
|
||||
|
||||
_default_addrs: dict[str, Type[Address]] = {
|
||||
cls.name_key: cls
|
||||
for cls in _address_types
|
||||
}
|
||||
|
||||
|
||||
AddressTypes = Union[
|
||||
tuple([
|
||||
cls.address_type
|
||||
for cls in _address_types
|
||||
])
|
||||
]
|
||||
|
||||
|
||||
_default_lo_addrs: dict[
|
||||
str,
|
||||
AddressTypes
|
||||
] = {
|
||||
cls.name_key: cls.get_root().unwrap()
|
||||
for cls in _address_types
|
||||
}
|
||||
|
||||
|
||||
def get_address_cls(name: str) -> Type[Address]:
|
||||
return _default_addrs[name]
|
||||
|
||||
|
||||
def is_wrapped_addr(addr: any) -> bool:
|
||||
return type(addr) in _address_types
|
||||
|
||||
|
||||
def wrap_address(addr: AddressTypes) -> Address:
|
||||
|
||||
if is_wrapped_addr(addr):
|
||||
return addr
|
||||
|
||||
cls = None
|
||||
match addr:
|
||||
case str():
|
||||
cls = UDSAddress
|
||||
|
||||
case tuple() | list():
|
||||
cls = TCPAddress
|
||||
|
||||
case None:
|
||||
cls = get_address_cls(preferred_transport)
|
||||
addr = cls.get_root().unwrap()
|
||||
|
||||
case _:
|
||||
raise TypeError(
|
||||
f'Can not wrap addr {addr} of type {type(addr)}'
|
||||
)
|
||||
|
||||
return cls.from_addr(addr)
|
||||
|
||||
|
||||
def default_lo_addrs(transports: list[str]) -> list[AddressTypes]:
|
||||
return [
|
||||
_default_lo_addrs[transport]
|
||||
for transport in transports
|
||||
]
|
|
@ -31,8 +31,12 @@ def parse_uid(arg):
|
|||
return str(name), str(uuid) # ensures str encoding
|
||||
|
||||
def parse_ipaddr(arg):
|
||||
host, port = literal_eval(arg)
|
||||
return (str(host), int(port))
|
||||
try:
|
||||
return literal_eval(arg)
|
||||
|
||||
except (ValueError, SyntaxError):
|
||||
# UDS: try to interpret as a straight up str
|
||||
return arg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -859,19 +859,10 @@ class Context:
|
|||
@property
|
||||
def dst_maddr(self) -> str:
|
||||
chan: Channel = self.chan
|
||||
dst_addr, dst_port = chan.raddr
|
||||
trans: MsgTransport = chan.transport
|
||||
# cid: str = self.cid
|
||||
# cid_head, cid_tail = cid[:6], cid[-6:]
|
||||
return (
|
||||
f'/ipv4/{dst_addr}'
|
||||
f'/{trans.name_key}/{dst_port}'
|
||||
# f'/{self.chan.uid[0]}'
|
||||
# f'/{self.cid}'
|
||||
|
||||
# f'/cid={cid_head}..{cid_tail}'
|
||||
# TODO: ? not use this ^ right ?
|
||||
)
|
||||
return trans.maddr
|
||||
|
||||
dmaddr = dst_maddr
|
||||
|
||||
|
|
|
@ -30,6 +30,12 @@ from contextlib import asynccontextmanager as acm
|
|||
from tractor.log import get_logger
|
||||
from .trionics import gather_contexts
|
||||
from .ipc import _connect_chan, Channel
|
||||
from ._addr import (
|
||||
AddressTypes,
|
||||
Address,
|
||||
preferred_transport,
|
||||
wrap_address
|
||||
)
|
||||
from ._portal import (
|
||||
Portal,
|
||||
open_portal,
|
||||
|
@ -48,11 +54,7 @@ log = get_logger(__name__)
|
|||
|
||||
|
||||
@acm
|
||||
async def get_registry(
|
||||
host: str,
|
||||
port: int,
|
||||
|
||||
) -> AsyncGenerator[
|
||||
async def get_registry(addr: AddressTypes | None = None) -> AsyncGenerator[
|
||||
Portal | LocalPortal | None,
|
||||
None,
|
||||
]:
|
||||
|
@ -69,13 +71,13 @@ async def get_registry(
|
|||
# (likely a re-entrant call from the arbiter actor)
|
||||
yield LocalPortal(
|
||||
actor,
|
||||
Channel((host, port))
|
||||
await Channel.from_addr(addr)
|
||||
)
|
||||
else:
|
||||
# TODO: try to look pre-existing connection from
|
||||
# `Actor._peers` and use it instead?
|
||||
async with (
|
||||
_connect_chan(host, port) as chan,
|
||||
_connect_chan(addr) as chan,
|
||||
open_portal(chan) as regstr_ptl,
|
||||
):
|
||||
yield regstr_ptl
|
||||
|
@ -89,11 +91,10 @@ async def get_root(
|
|||
|
||||
# TODO: rename mailbox to `_root_maddr` when we finally
|
||||
# add and impl libp2p multi-addrs?
|
||||
host, port = _runtime_vars['_root_mailbox']
|
||||
assert host is not None
|
||||
addr = _runtime_vars['_root_mailbox']
|
||||
|
||||
async with (
|
||||
_connect_chan(host, port) as chan,
|
||||
_connect_chan(addr) as chan,
|
||||
open_portal(chan, **kwargs) as portal,
|
||||
):
|
||||
yield portal
|
||||
|
@ -134,10 +135,10 @@ def get_peer_by_name(
|
|||
@acm
|
||||
async def query_actor(
|
||||
name: str,
|
||||
regaddr: tuple[str, int]|None = None,
|
||||
regaddr: AddressTypes|None = None,
|
||||
|
||||
) -> AsyncGenerator[
|
||||
tuple[str, int]|None,
|
||||
AddressTypes|None,
|
||||
None,
|
||||
]:
|
||||
'''
|
||||
|
@ -163,31 +164,31 @@ async def query_actor(
|
|||
return
|
||||
|
||||
reg_portal: Portal
|
||||
regaddr: tuple[str, int] = regaddr or actor.reg_addrs[0]
|
||||
async with get_registry(*regaddr) as reg_portal:
|
||||
regaddr: Address = wrap_address(regaddr) or actor.reg_addrs[0]
|
||||
async with get_registry(regaddr) as reg_portal:
|
||||
# TODO: return portals to all available actors - for now
|
||||
# just the last one that registered
|
||||
sockaddr: tuple[str, int] = await reg_portal.run_from_ns(
|
||||
addr: AddressTypes = await reg_portal.run_from_ns(
|
||||
'self',
|
||||
'find_actor',
|
||||
name=name,
|
||||
)
|
||||
yield sockaddr
|
||||
yield addr
|
||||
|
||||
|
||||
@acm
|
||||
async def maybe_open_portal(
|
||||
addr: tuple[str, int],
|
||||
addr: AddressTypes,
|
||||
name: str,
|
||||
):
|
||||
async with query_actor(
|
||||
name=name,
|
||||
regaddr=addr,
|
||||
) as sockaddr:
|
||||
) as addr:
|
||||
pass
|
||||
|
||||
if sockaddr:
|
||||
async with _connect_chan(*sockaddr) as chan:
|
||||
if addr:
|
||||
async with _connect_chan(addr) as chan:
|
||||
async with open_portal(chan) as portal:
|
||||
yield portal
|
||||
else:
|
||||
|
@ -197,7 +198,8 @@ async def maybe_open_portal(
|
|||
@acm
|
||||
async def find_actor(
|
||||
name: str,
|
||||
registry_addrs: list[tuple[str, int]]|None = None,
|
||||
registry_addrs: list[AddressTypes]|None = None,
|
||||
enable_transports: list[str] = [preferred_transport],
|
||||
|
||||
only_first: bool = True,
|
||||
raise_on_none: bool = False,
|
||||
|
@ -224,15 +226,15 @@ async def find_actor(
|
|||
# XXX NOTE: make sure to dynamically read the value on
|
||||
# every call since something may change it globally (eg.
|
||||
# like in our discovery test suite)!
|
||||
from . import _root
|
||||
from ._addr import default_lo_addrs
|
||||
registry_addrs = (
|
||||
_runtime_vars['_registry_addrs']
|
||||
or
|
||||
_root._default_lo_addrs
|
||||
default_lo_addrs(enable_transports)
|
||||
)
|
||||
|
||||
maybe_portals: list[
|
||||
AsyncContextManager[tuple[str, int]]
|
||||
AsyncContextManager[AddressTypes]
|
||||
] = list(
|
||||
maybe_open_portal(
|
||||
addr=addr,
|
||||
|
@ -274,7 +276,7 @@ async def find_actor(
|
|||
@acm
|
||||
async def wait_for_actor(
|
||||
name: str,
|
||||
registry_addr: tuple[str, int] | None = None,
|
||||
registry_addr: AddressTypes | None = None,
|
||||
|
||||
) -> AsyncGenerator[Portal, None]:
|
||||
'''
|
||||
|
@ -291,7 +293,7 @@ async def wait_for_actor(
|
|||
yield peer_portal
|
||||
return
|
||||
|
||||
regaddr: tuple[str, int] = (
|
||||
regaddr: AddressTypes = (
|
||||
registry_addr
|
||||
or
|
||||
actor.reg_addrs[0]
|
||||
|
@ -299,8 +301,8 @@ async def wait_for_actor(
|
|||
# TODO: use `.trionics.gather_contexts()` like
|
||||
# above in `find_actor()` as well?
|
||||
reg_portal: Portal
|
||||
async with get_registry(*regaddr) as reg_portal:
|
||||
sockaddrs = await reg_portal.run_from_ns(
|
||||
async with get_registry(regaddr) as reg_portal:
|
||||
addrs = await reg_portal.run_from_ns(
|
||||
'self',
|
||||
'wait_for_actor',
|
||||
name=name,
|
||||
|
@ -308,8 +310,8 @@ async def wait_for_actor(
|
|||
|
||||
# get latest registered addr by default?
|
||||
# TODO: offer multi-portal yields in multi-homed case?
|
||||
sockaddr: tuple[str, int] = sockaddrs[-1]
|
||||
addr: AddressTypes = addrs[-1]
|
||||
|
||||
async with _connect_chan(*sockaddr) as chan:
|
||||
async with _connect_chan(addr) as chan:
|
||||
async with open_portal(chan) as portal:
|
||||
yield portal
|
||||
|
|
|
@ -37,6 +37,7 @@ from .log import (
|
|||
from . import _state
|
||||
from .devx import _debug
|
||||
from .to_asyncio import run_as_asyncio_guest
|
||||
from ._addr import AddressTypes
|
||||
from ._runtime import (
|
||||
async_main,
|
||||
Actor,
|
||||
|
@ -52,10 +53,10 @@ log = get_logger(__name__)
|
|||
def _mp_main(
|
||||
|
||||
actor: Actor,
|
||||
accept_addrs: list[tuple[str, int]],
|
||||
accept_addrs: list[AddressTypes],
|
||||
forkserver_info: tuple[Any, Any, Any, Any, Any],
|
||||
start_method: SpawnMethodKey,
|
||||
parent_addr: tuple[str, int] | None = None,
|
||||
parent_addr: AddressTypes | None = None,
|
||||
infect_asyncio: bool = False,
|
||||
|
||||
) -> None:
|
||||
|
@ -206,7 +207,7 @@ def nest_from_op(
|
|||
def _trio_main(
|
||||
actor: Actor,
|
||||
*,
|
||||
parent_addr: tuple[str, int] | None = None,
|
||||
parent_addr: AddressTypes | None = None,
|
||||
infect_asyncio: bool = False,
|
||||
|
||||
) -> None:
|
||||
|
|
|
@ -43,21 +43,18 @@ from .devx import _debug
|
|||
from . import _spawn
|
||||
from . import _state
|
||||
from . import log
|
||||
from .ipc import _connect_chan
|
||||
from .ipc import (
|
||||
_connect_chan,
|
||||
)
|
||||
from ._addr import (
|
||||
AddressTypes,
|
||||
wrap_address,
|
||||
preferred_transport,
|
||||
default_lo_addrs
|
||||
)
|
||||
from ._exceptions import is_multi_cancelled
|
||||
|
||||
|
||||
# set at startup and after forks
|
||||
_default_host: str = '127.0.0.1'
|
||||
_default_port: int = 1616
|
||||
|
||||
# default registry always on localhost
|
||||
_default_lo_addrs: list[tuple[str, int]] = [(
|
||||
_default_host,
|
||||
_default_port,
|
||||
)]
|
||||
|
||||
|
||||
logger = log.get_logger('tractor')
|
||||
|
||||
|
||||
|
@ -66,10 +63,12 @@ async def open_root_actor(
|
|||
|
||||
*,
|
||||
# defaults are above
|
||||
registry_addrs: list[tuple[str, int]]|None = None,
|
||||
registry_addrs: list[AddressTypes]|None = None,
|
||||
|
||||
# defaults are above
|
||||
arbiter_addr: tuple[str, int]|None = None,
|
||||
arbiter_addr: tuple[AddressTypes]|None = None,
|
||||
|
||||
enable_transports: list[str] = [preferred_transport],
|
||||
|
||||
name: str|None = 'root',
|
||||
|
||||
|
@ -195,11 +194,9 @@ async def open_root_actor(
|
|||
)
|
||||
registry_addrs = [arbiter_addr]
|
||||
|
||||
registry_addrs: list[tuple[str, int]] = (
|
||||
registry_addrs
|
||||
or
|
||||
_default_lo_addrs
|
||||
)
|
||||
if not registry_addrs:
|
||||
registry_addrs: list[AddressTypes] = default_lo_addrs(enable_transports)
|
||||
|
||||
assert registry_addrs
|
||||
|
||||
loglevel = (
|
||||
|
@ -248,10 +245,10 @@ async def open_root_actor(
|
|||
enable_stack_on_sig()
|
||||
|
||||
# closed into below ping task-func
|
||||
ponged_addrs: list[tuple[str, int]] = []
|
||||
ponged_addrs: list[AddressTypes] = []
|
||||
|
||||
async def ping_tpt_socket(
|
||||
addr: tuple[str, int],
|
||||
addr: AddressTypes,
|
||||
timeout: float = 1,
|
||||
) -> None:
|
||||
'''
|
||||
|
@ -271,7 +268,7 @@ async def open_root_actor(
|
|||
# be better to eventually have a "discovery" protocol
|
||||
# with basic handshake instead?
|
||||
with trio.move_on_after(timeout):
|
||||
async with _connect_chan(*addr):
|
||||
async with _connect_chan(addr):
|
||||
ponged_addrs.append(addr)
|
||||
|
||||
except OSError:
|
||||
|
@ -284,10 +281,10 @@ async def open_root_actor(
|
|||
for addr in registry_addrs:
|
||||
tn.start_soon(
|
||||
ping_tpt_socket,
|
||||
tuple(addr), # TODO: just drop this requirement?
|
||||
addr,
|
||||
)
|
||||
|
||||
trans_bind_addrs: list[tuple[str, int]] = []
|
||||
trans_bind_addrs: list[AddressTypes] = []
|
||||
|
||||
# Create a new local root-actor instance which IS NOT THE
|
||||
# REGISTRAR
|
||||
|
@ -311,9 +308,12 @@ async def open_root_actor(
|
|||
)
|
||||
# DO NOT use the registry_addrs as the transport server
|
||||
# addrs for this new non-registar, root-actor.
|
||||
for host, port in ponged_addrs:
|
||||
# NOTE: zero triggers dynamic OS port allocation
|
||||
trans_bind_addrs.append((host, 0))
|
||||
for addr in ponged_addrs:
|
||||
waddr = wrap_address(addr)
|
||||
print(waddr)
|
||||
trans_bind_addrs.append(
|
||||
waddr.get_random(namespace=waddr.namespace)
|
||||
)
|
||||
|
||||
# Start this local actor as the "registrar", aka a regular
|
||||
# actor who manages the local registry of "mailboxes" of
|
||||
|
@ -322,7 +322,7 @@ async def open_root_actor(
|
|||
|
||||
# NOTE that if the current actor IS THE REGISTAR, the
|
||||
# following init steps are taken:
|
||||
# - the tranport layer server is bound to each (host, port)
|
||||
# - the tranport layer server is bound to each addr
|
||||
# pair defined in provided registry_addrs, or the default.
|
||||
trans_bind_addrs = registry_addrs
|
||||
|
||||
|
@ -462,7 +462,7 @@ def run_daemon(
|
|||
|
||||
# runtime kwargs
|
||||
name: str | None = 'root',
|
||||
registry_addrs: list[tuple[str, int]] = _default_lo_addrs,
|
||||
registry_addrs: list[AddressTypes]|None = None,
|
||||
|
||||
start_method: str | None = None,
|
||||
debug_mode: bool = False,
|
||||
|
|
|
@ -74,6 +74,13 @@ from tractor.msg import (
|
|||
types as msgtypes,
|
||||
)
|
||||
from .ipc import Channel
|
||||
from ._addr import (
|
||||
AddressTypes,
|
||||
Address,
|
||||
wrap_address,
|
||||
preferred_transport,
|
||||
default_lo_addrs
|
||||
)
|
||||
from ._context import (
|
||||
mk_context,
|
||||
Context,
|
||||
|
@ -179,11 +186,11 @@ class Actor:
|
|||
enable_modules: list[str] = [],
|
||||
uid: str|None = None,
|
||||
loglevel: str|None = None,
|
||||
registry_addrs: list[tuple[str, int]]|None = None,
|
||||
registry_addrs: list[AddressTypes]|None = None,
|
||||
spawn_method: str|None = None,
|
||||
|
||||
# TODO: remove!
|
||||
arbiter_addr: tuple[str, int]|None = None,
|
||||
arbiter_addr: AddressTypes|None = None,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
|
@ -223,7 +230,7 @@ class Actor:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
registry_addrs: list[tuple[str, int]] = [arbiter_addr]
|
||||
registry_addrs: list[AddressTypes] = [arbiter_addr]
|
||||
|
||||
# marked by the process spawning backend at startup
|
||||
# will be None for the parent most process started manually
|
||||
|
@ -257,6 +264,7 @@ class Actor:
|
|||
] = {}
|
||||
|
||||
self._listeners: list[trio.abc.Listener] = []
|
||||
self._listen_addrs: list[Address] = []
|
||||
self._parent_chan: Channel|None = None
|
||||
self._forkserver_info: tuple|None = None
|
||||
|
||||
|
@ -269,13 +277,13 @@ class Actor:
|
|||
|
||||
# when provided, init the registry addresses property from
|
||||
# input via the validator.
|
||||
self._reg_addrs: list[tuple[str, int]] = []
|
||||
self._reg_addrs: list[AddressTypes] = []
|
||||
if registry_addrs:
|
||||
self.reg_addrs: list[tuple[str, int]] = registry_addrs
|
||||
self.reg_addrs: list[AddressTypes] = registry_addrs
|
||||
_state._runtime_vars['_registry_addrs'] = registry_addrs
|
||||
|
||||
@property
|
||||
def reg_addrs(self) -> list[tuple[str, int]]:
|
||||
def reg_addrs(self) -> list[AddressTypes]:
|
||||
'''
|
||||
List of (socket) addresses for all known (and contactable)
|
||||
registry actors.
|
||||
|
@ -286,7 +294,7 @@ class Actor:
|
|||
@reg_addrs.setter
|
||||
def reg_addrs(
|
||||
self,
|
||||
addrs: list[tuple[str, int]],
|
||||
addrs: list[AddressTypes],
|
||||
) -> None:
|
||||
if not addrs:
|
||||
log.warning(
|
||||
|
@ -295,16 +303,7 @@ class Actor:
|
|||
)
|
||||
return
|
||||
|
||||
# always sanity check the input list since it's critical
|
||||
# that addrs are correct for discovery sys operation.
|
||||
for addr in addrs:
|
||||
if not isinstance(addr, tuple):
|
||||
raise ValueError(
|
||||
'Expected `Actor.reg_addrs: list[tuple[str, int]]`\n'
|
||||
f'Got {addrs}'
|
||||
)
|
||||
|
||||
self._reg_addrs = addrs
|
||||
self._reg_addrs = addrs
|
||||
|
||||
async def wait_for_peer(
|
||||
self,
|
||||
|
@ -1024,11 +1023,11 @@ class Actor:
|
|||
|
||||
async def _from_parent(
|
||||
self,
|
||||
parent_addr: tuple[str, int]|None,
|
||||
parent_addr: AddressTypes|None,
|
||||
|
||||
) -> tuple[
|
||||
Channel,
|
||||
list[tuple[str, int]]|None,
|
||||
list[AddressTypes]|None,
|
||||
]:
|
||||
'''
|
||||
Bootstrap this local actor's runtime config from its parent by
|
||||
|
@ -1040,16 +1039,13 @@ class Actor:
|
|||
# Connect back to the parent actor and conduct initial
|
||||
# handshake. From this point on if we error, we
|
||||
# attempt to ship the exception back to the parent.
|
||||
chan = Channel(
|
||||
destaddr=parent_addr,
|
||||
)
|
||||
await chan.connect()
|
||||
chan = await Channel.from_addr(wrap_address(parent_addr))
|
||||
|
||||
# TODO: move this into a `Channel.handshake()`?
|
||||
# Initial handshake: swap names.
|
||||
await self._do_handshake(chan)
|
||||
|
||||
accept_addrs: list[tuple[str, int]]|None = None
|
||||
accept_addrs: list[AddressTypes]|None = None
|
||||
|
||||
if self._spawn_method == "trio":
|
||||
|
||||
|
@ -1066,7 +1062,7 @@ class Actor:
|
|||
# if "trace"/"util" mode is enabled?
|
||||
f'{pretty_struct.pformat(spawnspec)}\n'
|
||||
)
|
||||
accept_addrs: list[tuple[str, int]] = spawnspec.bind_addrs
|
||||
accept_addrs: list[AddressTypes] = spawnspec.bind_addrs
|
||||
|
||||
# TODO: another `Struct` for rtvs..
|
||||
rvs: dict[str, Any] = spawnspec._runtime_vars
|
||||
|
@ -1173,8 +1169,7 @@ class Actor:
|
|||
self,
|
||||
handler_nursery: Nursery,
|
||||
*,
|
||||
# (host, port) to bind for channel server
|
||||
listen_sockaddrs: list[tuple[str, int]]|None = None,
|
||||
listen_addrs: list[AddressTypes]|None = None,
|
||||
|
||||
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
|
||||
) -> None:
|
||||
|
@ -1186,41 +1181,45 @@ class Actor:
|
|||
`.cancel_server()` is called.
|
||||
|
||||
'''
|
||||
if listen_sockaddrs is None:
|
||||
listen_sockaddrs = [(None, 0)]
|
||||
if listen_addrs is None:
|
||||
listen_addrs = default_lo_addrs([preferred_transport])
|
||||
|
||||
else:
|
||||
listen_addrs: list[Address] = [
|
||||
wrap_address(a) for a in listen_addrs
|
||||
]
|
||||
|
||||
self._server_down = trio.Event()
|
||||
try:
|
||||
async with trio.open_nursery() as server_n:
|
||||
listeners: list[trio.abc.Listener] = [
|
||||
await addr.open_listener()
|
||||
for addr in listen_addrs
|
||||
]
|
||||
await server_n.start(
|
||||
partial(
|
||||
trio.serve_listeners,
|
||||
handler=self._stream_handler,
|
||||
listeners=listeners,
|
||||
|
||||
for host, port in listen_sockaddrs:
|
||||
listeners: list[trio.abc.Listener] = await server_n.start(
|
||||
partial(
|
||||
trio.serve_tcp,
|
||||
|
||||
handler=self._stream_handler,
|
||||
port=port,
|
||||
host=host,
|
||||
|
||||
# NOTE: configured such that new
|
||||
# connections will stay alive even if
|
||||
# this server is cancelled!
|
||||
handler_nursery=handler_nursery,
|
||||
)
|
||||
# NOTE: configured such that new
|
||||
# connections will stay alive even if
|
||||
# this server is cancelled!
|
||||
handler_nursery=handler_nursery
|
||||
)
|
||||
sockets: list[trio.socket] = [
|
||||
getattr(listener, 'socket', 'unknown socket')
|
||||
for listener in listeners
|
||||
]
|
||||
log.runtime(
|
||||
'Started TCP server(s)\n'
|
||||
f'|_{sockets}\n'
|
||||
)
|
||||
self._listeners.extend(listeners)
|
||||
)
|
||||
log.runtime(
|
||||
'Started server(s)\n'
|
||||
'\n'.join([f'|_{addr}' for addr in listen_addrs])
|
||||
)
|
||||
self._listen_addrs.extend(listen_addrs)
|
||||
self._listeners.extend(listeners)
|
||||
|
||||
task_status.started(server_n)
|
||||
|
||||
finally:
|
||||
for addr in listen_addrs:
|
||||
await addr.close_listener()
|
||||
# signal the server is down since nursery above terminated
|
||||
self._server_down.set()
|
||||
|
||||
|
@ -1579,26 +1578,21 @@ class Actor:
|
|||
return False
|
||||
|
||||
@property
|
||||
def accept_addrs(self) -> list[tuple[str, int]]:
|
||||
def accept_addrs(self) -> list[AddressTypes]:
|
||||
'''
|
||||
All addresses to which the transport-channel server binds
|
||||
and listens for new connections.
|
||||
|
||||
'''
|
||||
# throws OSError on failure
|
||||
return [
|
||||
listener.socket.getsockname()
|
||||
for listener in self._listeners
|
||||
] # type: ignore
|
||||
return [a.unwrap() for a in self._listen_addrs]
|
||||
|
||||
@property
|
||||
def accept_addr(self) -> tuple[str, int]:
|
||||
def accept_addr(self) -> AddressTypes:
|
||||
'''
|
||||
Primary address to which the IPC transport server is
|
||||
bound and listening for new connections.
|
||||
|
||||
'''
|
||||
# throws OSError on failure
|
||||
return self.accept_addrs[0]
|
||||
|
||||
def get_parent(self) -> Portal:
|
||||
|
@ -1670,7 +1664,7 @@ class Actor:
|
|||
|
||||
async def async_main(
|
||||
actor: Actor,
|
||||
accept_addrs: tuple[str, int]|None = None,
|
||||
accept_addrs: AddressTypes|None = None,
|
||||
|
||||
# XXX: currently ``parent_addr`` is only needed for the
|
||||
# ``multiprocessing`` backend (which pickles state sent to
|
||||
|
@ -1679,7 +1673,7 @@ async def async_main(
|
|||
# change this to a simple ``is_subactor: bool`` which will
|
||||
# be False when running as root actor and True when as
|
||||
# a subactor.
|
||||
parent_addr: tuple[str, int]|None = None,
|
||||
parent_addr: AddressTypes|None = None,
|
||||
task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED,
|
||||
|
||||
) -> None:
|
||||
|
@ -1769,7 +1763,7 @@ async def async_main(
|
|||
partial(
|
||||
actor._serve_forever,
|
||||
service_nursery,
|
||||
listen_sockaddrs=accept_addrs,
|
||||
listen_addrs=accept_addrs,
|
||||
)
|
||||
)
|
||||
except OSError as oserr:
|
||||
|
@ -1785,7 +1779,7 @@ async def async_main(
|
|||
|
||||
raise
|
||||
|
||||
accept_addrs: list[tuple[str, int]] = actor.accept_addrs
|
||||
accept_addrs: list[AddressTypes] = actor.accept_addrs
|
||||
|
||||
# NOTE: only set the loopback addr for the
|
||||
# process-tree-global "root" mailbox since
|
||||
|
@ -1793,9 +1787,8 @@ async def async_main(
|
|||
# their root actor over that channel.
|
||||
if _state._runtime_vars['_is_root']:
|
||||
for addr in accept_addrs:
|
||||
host, _ = addr
|
||||
# TODO: generic 'lo' detector predicate
|
||||
if '127.0.0.1' in host:
|
||||
waddr = wrap_address(addr)
|
||||
if waddr == waddr.get_root():
|
||||
_state._runtime_vars['_root_mailbox'] = addr
|
||||
|
||||
# Register with the arbiter if we're told its addr
|
||||
|
@ -1810,24 +1803,21 @@ async def async_main(
|
|||
# only on unique actor uids?
|
||||
for addr in actor.reg_addrs:
|
||||
try:
|
||||
assert isinstance(addr, tuple)
|
||||
assert addr[1] # non-zero after bind
|
||||
waddr = wrap_address(addr)
|
||||
assert waddr.is_valid
|
||||
except AssertionError:
|
||||
await _debug.pause()
|
||||
|
||||
async with get_registry(*addr) as reg_portal:
|
||||
async with get_registry(addr) as reg_portal:
|
||||
for accept_addr in accept_addrs:
|
||||
|
||||
if not accept_addr[1]:
|
||||
await _debug.pause()
|
||||
|
||||
assert accept_addr[1]
|
||||
accept_addr = wrap_address(accept_addr)
|
||||
assert accept_addr.is_valid
|
||||
|
||||
await reg_portal.run_from_ns(
|
||||
'self',
|
||||
'register_actor',
|
||||
uid=actor.uid,
|
||||
sockaddr=accept_addr,
|
||||
addr=accept_addr.unwrap(),
|
||||
)
|
||||
|
||||
is_registered: bool = True
|
||||
|
@ -1954,12 +1944,13 @@ async def async_main(
|
|||
):
|
||||
failed: bool = False
|
||||
for addr in actor.reg_addrs:
|
||||
assert isinstance(addr, tuple)
|
||||
waddr = wrap_address(addr)
|
||||
assert waddr.is_valid
|
||||
with trio.move_on_after(0.5) as cs:
|
||||
cs.shield = True
|
||||
try:
|
||||
async with get_registry(
|
||||
*addr,
|
||||
addr,
|
||||
) as reg_portal:
|
||||
await reg_portal.run_from_ns(
|
||||
'self',
|
||||
|
@ -2037,7 +2028,7 @@ class Arbiter(Actor):
|
|||
|
||||
self._registry: dict[
|
||||
tuple[str, str],
|
||||
tuple[str, int],
|
||||
AddressTypes,
|
||||
] = {}
|
||||
self._waiters: dict[
|
||||
str,
|
||||
|
@ -2053,18 +2044,18 @@ class Arbiter(Actor):
|
|||
self,
|
||||
name: str,
|
||||
|
||||
) -> tuple[str, int]|None:
|
||||
) -> AddressTypes|None:
|
||||
|
||||
for uid, sockaddr in self._registry.items():
|
||||
for uid, addr in self._registry.items():
|
||||
if name in uid:
|
||||
return sockaddr
|
||||
return addr
|
||||
|
||||
return None
|
||||
|
||||
async def get_registry(
|
||||
self
|
||||
|
||||
) -> dict[str, tuple[str, int]]:
|
||||
) -> dict[str, AddressTypes]:
|
||||
'''
|
||||
Return current name registry.
|
||||
|
||||
|
@ -2084,7 +2075,7 @@ class Arbiter(Actor):
|
|||
self,
|
||||
name: str,
|
||||
|
||||
) -> list[tuple[str, int]]:
|
||||
) -> list[AddressTypes]:
|
||||
'''
|
||||
Wait for a particular actor to register.
|
||||
|
||||
|
@ -2092,44 +2083,41 @@ class Arbiter(Actor):
|
|||
registered.
|
||||
|
||||
'''
|
||||
sockaddrs: list[tuple[str, int]] = []
|
||||
sockaddr: tuple[str, int]
|
||||
addrs: list[AddressTypes] = []
|
||||
addr: AddressTypes
|
||||
|
||||
mailbox_info: str = 'Actor registry contact infos:\n'
|
||||
for uid, sockaddr in self._registry.items():
|
||||
for uid, addr in self._registry.items():
|
||||
mailbox_info += (
|
||||
f'|_uid: {uid}\n'
|
||||
f'|_sockaddr: {sockaddr}\n\n'
|
||||
f'|_addr: {addr}\n\n'
|
||||
)
|
||||
if name == uid[0]:
|
||||
sockaddrs.append(sockaddr)
|
||||
addrs.append(addr)
|
||||
|
||||
if not sockaddrs:
|
||||
if not addrs:
|
||||
waiter = trio.Event()
|
||||
self._waiters.setdefault(name, []).append(waiter)
|
||||
await waiter.wait()
|
||||
|
||||
for uid in self._waiters[name]:
|
||||
if not isinstance(uid, trio.Event):
|
||||
sockaddrs.append(self._registry[uid])
|
||||
addrs.append(self._registry[uid])
|
||||
|
||||
log.runtime(mailbox_info)
|
||||
return sockaddrs
|
||||
return addrs
|
||||
|
||||
async def register_actor(
|
||||
self,
|
||||
uid: tuple[str, str],
|
||||
sockaddr: tuple[str, int]
|
||||
|
||||
addr: AddressTypes
|
||||
) -> None:
|
||||
uid = name, hash = (str(uid[0]), str(uid[1]))
|
||||
addr = (host, port) = (
|
||||
str(sockaddr[0]),
|
||||
int(sockaddr[1]),
|
||||
)
|
||||
if port == 0:
|
||||
waddr: Address = wrap_address(addr)
|
||||
if not waddr.is_valid:
|
||||
# should never be 0-dynamic-os-alloc
|
||||
await _debug.pause()
|
||||
assert port # should never be 0-dynamic-os-alloc
|
||||
|
||||
self._registry[uid] = addr
|
||||
|
||||
# pop and signal all waiter events
|
||||
|
|
|
@ -46,6 +46,7 @@ from tractor._state import (
|
|||
_runtime_vars,
|
||||
)
|
||||
from tractor.log import get_logger
|
||||
from tractor._addr import AddressTypes
|
||||
from tractor._portal import Portal
|
||||
from tractor._runtime import Actor
|
||||
from tractor._entry import _mp_main
|
||||
|
@ -392,8 +393,8 @@ async def new_proc(
|
|||
errors: dict[tuple[str, str], Exception],
|
||||
|
||||
# passed through to actor main
|
||||
bind_addrs: list[tuple[str, int]],
|
||||
parent_addr: tuple[str, int],
|
||||
bind_addrs: list[AddressTypes],
|
||||
parent_addr: AddressTypes,
|
||||
_runtime_vars: dict[str, Any], # serialized and sent to _child
|
||||
|
||||
*,
|
||||
|
@ -431,8 +432,8 @@ async def trio_proc(
|
|||
errors: dict[tuple[str, str], Exception],
|
||||
|
||||
# passed through to actor main
|
||||
bind_addrs: list[tuple[str, int]],
|
||||
parent_addr: tuple[str, int],
|
||||
bind_addrs: list[AddressTypes],
|
||||
parent_addr: AddressTypes,
|
||||
_runtime_vars: dict[str, Any], # serialized and sent to _child
|
||||
*,
|
||||
infect_asyncio: bool = False,
|
||||
|
@ -520,15 +521,15 @@ async def trio_proc(
|
|||
|
||||
# send a "spawning specification" which configures the
|
||||
# initial runtime state of the child.
|
||||
await chan.send(
|
||||
SpawnSpec(
|
||||
_parent_main_data=subactor._parent_main_data,
|
||||
enable_modules=subactor.enable_modules,
|
||||
reg_addrs=subactor.reg_addrs,
|
||||
bind_addrs=bind_addrs,
|
||||
_runtime_vars=_runtime_vars,
|
||||
)
|
||||
sspec = SpawnSpec(
|
||||
_parent_main_data=subactor._parent_main_data,
|
||||
enable_modules=subactor.enable_modules,
|
||||
reg_addrs=subactor.reg_addrs,
|
||||
bind_addrs=bind_addrs,
|
||||
_runtime_vars=_runtime_vars,
|
||||
)
|
||||
log.runtime(f'Sending spawn spec: {str(sspec)}')
|
||||
await chan.send(sspec)
|
||||
|
||||
# track subactor in current nursery
|
||||
curr_actor: Actor = current_actor()
|
||||
|
@ -638,8 +639,8 @@ async def mp_proc(
|
|||
subactor: Actor,
|
||||
errors: dict[tuple[str, str], Exception],
|
||||
# passed through to actor main
|
||||
bind_addrs: list[tuple[str, int]],
|
||||
parent_addr: tuple[str, int],
|
||||
bind_addrs: list[AddressTypes],
|
||||
parent_addr: AddressTypes,
|
||||
_runtime_vars: dict[str, Any], # serialized and sent to _child
|
||||
*,
|
||||
infect_asyncio: bool = False,
|
||||
|
|
|
@ -28,7 +28,13 @@ import warnings
|
|||
|
||||
import trio
|
||||
|
||||
|
||||
from .devx._debug import maybe_wait_for_debugger
|
||||
from ._addr import (
|
||||
AddressTypes,
|
||||
preferred_transport,
|
||||
get_address_cls
|
||||
)
|
||||
from ._state import current_actor, is_main_process
|
||||
from .log import get_logger, get_loglevel
|
||||
from ._runtime import Actor
|
||||
|
@ -47,8 +53,6 @@ if TYPE_CHECKING:
|
|||
|
||||
log = get_logger(__name__)
|
||||
|
||||
_default_bind_addr: tuple[str, int] = ('127.0.0.1', 0)
|
||||
|
||||
|
||||
class ActorNursery:
|
||||
'''
|
||||
|
@ -130,8 +134,9 @@ class ActorNursery:
|
|||
|
||||
*,
|
||||
|
||||
bind_addrs: list[tuple[str, int]] = [_default_bind_addr],
|
||||
bind_addrs: list[AddressTypes]|None = None,
|
||||
rpc_module_paths: list[str]|None = None,
|
||||
enable_transports: list[str] = [preferred_transport],
|
||||
enable_modules: list[str]|None = None,
|
||||
loglevel: str|None = None, # set log level per subactor
|
||||
debug_mode: bool|None = None,
|
||||
|
@ -156,6 +161,12 @@ class ActorNursery:
|
|||
or get_loglevel()
|
||||
)
|
||||
|
||||
if not bind_addrs:
|
||||
bind_addrs: list[AddressTypes] = [
|
||||
get_address_cls(transport).get_random().unwrap()
|
||||
for transport in enable_transports
|
||||
]
|
||||
|
||||
# configure and pass runtime state
|
||||
_rtv = _state._runtime_vars.copy()
|
||||
_rtv['_is_root'] = False
|
||||
|
@ -224,7 +235,7 @@ class ActorNursery:
|
|||
*,
|
||||
|
||||
name: str | None = None,
|
||||
bind_addrs: tuple[str, int] = [_default_bind_addr],
|
||||
bind_addrs: AddressTypes|None = None,
|
||||
rpc_module_paths: list[str] | None = None,
|
||||
enable_modules: list[str] | None = None,
|
||||
loglevel: str | None = None, # set log level per subactor
|
||||
|
|
|
@ -2,19 +2,59 @@ import os
|
|||
import random
|
||||
|
||||
|
||||
def generate_single_byte_msgs(amount: int) -> bytes:
|
||||
'''
|
||||
Generate a byte instance of len `amount` with:
|
||||
|
||||
```
|
||||
byte_at_index(i) = (i % 10).encode()
|
||||
```
|
||||
|
||||
this results in constantly repeating sequences of:
|
||||
|
||||
b'0123456789'
|
||||
|
||||
'''
|
||||
return b''.join(str(i % 10).encode() for i in range(amount))
|
||||
|
||||
|
||||
def generate_sample_messages(
|
||||
amount: int,
|
||||
rand_min: int = 0,
|
||||
rand_max: int = 0,
|
||||
silent: bool = False
|
||||
silent: bool = False,
|
||||
) -> tuple[list[bytes], int]:
|
||||
'''
|
||||
Generate bytes msgs for tests.
|
||||
|
||||
Messages will have the following format:
|
||||
|
||||
```
|
||||
b'[{i:08}]' + os.urandom(random.randint(rand_min, rand_max))
|
||||
```
|
||||
|
||||
so for message index 25:
|
||||
|
||||
b'[00000025]' + random_bytes
|
||||
|
||||
'''
|
||||
msgs = []
|
||||
size = 0
|
||||
|
||||
log_interval = None
|
||||
if not silent:
|
||||
print(f'\ngenerating {amount} messages...')
|
||||
|
||||
# calculate an apropiate log interval based on
|
||||
# max message size
|
||||
max_msg_size = 10 + rand_max
|
||||
|
||||
if max_msg_size <= 32 * 1024:
|
||||
log_interval = 10_000
|
||||
|
||||
else:
|
||||
log_interval = 1000
|
||||
|
||||
for i in range(amount):
|
||||
msg = f'[{i:08}]'.encode('utf-8')
|
||||
|
||||
|
@ -26,7 +66,13 @@ def generate_sample_messages(
|
|||
|
||||
msgs.append(msg)
|
||||
|
||||
if not silent and i and i % 10_000 == 0:
|
||||
if (
|
||||
not silent
|
||||
and
|
||||
i > 0
|
||||
and
|
||||
i % log_interval == 0
|
||||
):
|
||||
print(f'{i} generated')
|
||||
|
||||
if not silent:
|
||||
|
|
|
@ -13,20 +13,25 @@
|
|||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
import platform
|
||||
|
||||
from ._transport import MsgTransport as MsgTransport
|
||||
from ._transport import (
|
||||
MsgTransportKey as MsgTransportKey,
|
||||
MsgType as MsgType,
|
||||
MsgTransport as MsgTransport,
|
||||
MsgpackTransport as MsgpackTransport
|
||||
)
|
||||
|
||||
from ._tcp import (
|
||||
get_stream_addrs as get_stream_addrs,
|
||||
MsgpackTCPStream as MsgpackTCPStream
|
||||
from ._tcp import MsgpackTCPStream as MsgpackTCPStream
|
||||
from ._uds import MsgpackUDSStream as MsgpackUDSStream
|
||||
|
||||
from ._types import (
|
||||
transport_from_addr as transport_from_addr,
|
||||
transport_from_stream as transport_from_stream,
|
||||
)
|
||||
|
||||
from ._chan import (
|
||||
_connect_chan as _connect_chan,
|
||||
get_msg_transport as get_msg_transport,
|
||||
Channel as Channel
|
||||
)
|
||||
|
||||
|
@ -39,12 +44,23 @@ if platform.system() == 'Linux':
|
|||
write_eventfd as write_eventfd,
|
||||
read_eventfd as read_eventfd,
|
||||
close_eventfd as close_eventfd,
|
||||
EFDReadCancelled as EFDReadCancelled,
|
||||
EventFD as EventFD,
|
||||
)
|
||||
|
||||
from ._ringbuf import (
|
||||
RBToken as RBToken,
|
||||
open_ringbuf as open_ringbuf,
|
||||
RingBuffSender as RingBuffSender,
|
||||
RingBuffReceiver as RingBuffReceiver,
|
||||
open_ringbuf as open_ringbuf
|
||||
open_ringbuf_pair as open_ringbuf_pair,
|
||||
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
|
||||
attach_to_ringbuf_sender as attach_to_ringbuf_sender,
|
||||
attach_to_ringbuf_stream as attach_to_ringbuf_stream,
|
||||
RingBuffBytesSender as RingBuffBytesSender,
|
||||
RingBuffBytesReceiver as RingBuffBytesReceiver,
|
||||
RingBuffChannel as RingBuffChannel,
|
||||
attach_to_ringbuf_schannel as attach_to_ringbuf_schannel,
|
||||
attach_to_ringbuf_rchannel as attach_to_ringbuf_rchannel,
|
||||
attach_to_ringbuf_channel as attach_to_ringbuf_channel,
|
||||
)
|
||||
|
|
|
@ -29,15 +29,19 @@ from pprint import pformat
|
|||
import typing
|
||||
from typing import (
|
||||
Any,
|
||||
Type
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from tractor.ipc._transport import MsgTransport
|
||||
from tractor.ipc._tcp import (
|
||||
MsgpackTCPStream,
|
||||
get_stream_addrs
|
||||
from tractor.ipc._types import (
|
||||
transport_from_addr,
|
||||
transport_from_stream,
|
||||
)
|
||||
from tractor._addr import (
|
||||
wrap_address,
|
||||
Address,
|
||||
AddressTypes
|
||||
)
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
|
@ -52,17 +56,6 @@ log = get_logger(__name__)
|
|||
_is_windows = platform.system() == 'Windows'
|
||||
|
||||
|
||||
def get_msg_transport(
|
||||
|
||||
key: tuple[str, str],
|
||||
|
||||
) -> Type[MsgTransport]:
|
||||
|
||||
return {
|
||||
('msgpack', 'tcp'): MsgpackTCPStream,
|
||||
}[key]
|
||||
|
||||
|
||||
class Channel:
|
||||
'''
|
||||
An inter-process channel for communication between (remote) actors.
|
||||
|
@ -77,10 +70,7 @@ class Channel:
|
|||
def __init__(
|
||||
|
||||
self,
|
||||
destaddr: tuple[str, int]|None,
|
||||
|
||||
msg_transport_type_key: tuple[str, str] = ('msgpack', 'tcp'),
|
||||
|
||||
transport: MsgTransport|None = None,
|
||||
# TODO: optional reconnection support?
|
||||
# auto_reconnect: bool = False,
|
||||
# on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
||||
|
@ -90,13 +80,9 @@ class Channel:
|
|||
# self._recon_seq = on_reconnect
|
||||
# self._autorecon = auto_reconnect
|
||||
|
||||
self._destaddr = destaddr
|
||||
self._transport_key = msg_transport_type_key
|
||||
|
||||
# Either created in ``.connect()`` or passed in by
|
||||
# user in ``.from_stream()``.
|
||||
self._stream: trio.SocketStream|None = None
|
||||
self._transport: MsgTransport|None = None
|
||||
self._transport: MsgTransport|None = transport
|
||||
|
||||
# set after handshake - always uid of far end
|
||||
self.uid: tuple[str, str]|None = None
|
||||
|
@ -110,6 +96,10 @@ class Channel:
|
|||
# runtime.
|
||||
self._cancel_called: bool = False
|
||||
|
||||
@property
|
||||
def stream(self) -> trio.abc.Stream | None:
|
||||
return self._transport.stream if self._transport else None
|
||||
|
||||
@property
|
||||
def msgstream(self) -> MsgTransport:
|
||||
log.info(
|
||||
|
@ -124,52 +114,32 @@ class Channel:
|
|||
@classmethod
|
||||
def from_stream(
|
||||
cls,
|
||||
stream: trio.SocketStream,
|
||||
**kwargs,
|
||||
|
||||
stream: trio.abc.Stream,
|
||||
) -> Channel:
|
||||
|
||||
src, dst = get_stream_addrs(stream)
|
||||
chan = Channel(
|
||||
destaddr=dst,
|
||||
**kwargs,
|
||||
transport_cls = transport_from_stream(stream)
|
||||
return Channel(
|
||||
transport=transport_cls(stream)
|
||||
)
|
||||
|
||||
# set immediately here from provided instance
|
||||
chan._stream: trio.SocketStream = stream
|
||||
chan.set_msg_transport(stream)
|
||||
return chan
|
||||
@classmethod
|
||||
async def from_addr(
|
||||
cls,
|
||||
addr: AddressTypes,
|
||||
**kwargs
|
||||
) -> Channel:
|
||||
addr: Address = wrap_address(addr)
|
||||
transport_cls = transport_from_addr(addr)
|
||||
transport = await transport_cls.connect_to(addr, **kwargs)
|
||||
|
||||
def set_msg_transport(
|
||||
self,
|
||||
stream: trio.SocketStream,
|
||||
type_key: tuple[str, str]|None = None,
|
||||
|
||||
# XXX optionally provided codec pair for `msgspec`:
|
||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
||||
codec: MsgCodec|None = None,
|
||||
|
||||
) -> MsgTransport:
|
||||
type_key = (
|
||||
type_key
|
||||
or
|
||||
self._transport_key
|
||||
log.transport(
|
||||
f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}'
|
||||
)
|
||||
# get transport type, then
|
||||
self._transport = get_msg_transport(
|
||||
type_key
|
||||
# instantiate an instance of the msg-transport
|
||||
)(
|
||||
stream,
|
||||
codec=codec,
|
||||
)
|
||||
return self._transport
|
||||
return Channel(transport=transport)
|
||||
|
||||
@cm
|
||||
def apply_codec(
|
||||
self,
|
||||
codec: MsgCodec,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
Temporarily override the underlying IPC msg codec for
|
||||
|
@ -189,44 +159,20 @@ class Channel:
|
|||
return '<Channel with inactive transport?>'
|
||||
|
||||
return repr(
|
||||
self._transport.stream.socket._sock
|
||||
self._transport
|
||||
).replace( # type: ignore
|
||||
"socket.socket",
|
||||
"Channel",
|
||||
)
|
||||
|
||||
@property
|
||||
def laddr(self) -> tuple[str, int]|None:
|
||||
def laddr(self) -> Address|None:
|
||||
return self._transport.laddr if self._transport else None
|
||||
|
||||
@property
|
||||
def raddr(self) -> tuple[str, int]|None:
|
||||
def raddr(self) -> Address|None:
|
||||
return self._transport.raddr if self._transport else None
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
destaddr: tuple[Any, ...] | None = None,
|
||||
**kwargs
|
||||
|
||||
) -> MsgTransport:
|
||||
|
||||
if self.connected():
|
||||
raise RuntimeError("channel is already connected?")
|
||||
|
||||
destaddr = destaddr or self._destaddr
|
||||
assert isinstance(destaddr, tuple)
|
||||
|
||||
stream = await trio.open_tcp_stream(
|
||||
*destaddr,
|
||||
**kwargs
|
||||
)
|
||||
transport = self.set_msg_transport(stream)
|
||||
|
||||
log.transport(
|
||||
f'Opened channel[{type(transport)}]: {self.laddr} -> {self.raddr}'
|
||||
)
|
||||
return transport
|
||||
|
||||
# TODO: something like,
|
||||
# `pdbp.hideframe_on(errors=[MsgTypeError])`
|
||||
# instead of the `try/except` hack we have rn..
|
||||
|
@ -261,7 +207,11 @@ class Channel:
|
|||
# assert err
|
||||
__tracebackhide__: bool = False
|
||||
else:
|
||||
assert err.cid
|
||||
try:
|
||||
assert err.cid
|
||||
|
||||
except KeyError:
|
||||
raise err
|
||||
|
||||
raise
|
||||
|
||||
|
@ -388,17 +338,14 @@ class Channel:
|
|||
|
||||
@acm
|
||||
async def _connect_chan(
|
||||
host: str,
|
||||
port: int
|
||||
|
||||
addr: AddressTypes
|
||||
) -> typing.AsyncGenerator[Channel, None]:
|
||||
'''
|
||||
Create and connect a channel with disconnect on context manager
|
||||
teardown.
|
||||
|
||||
'''
|
||||
chan = Channel((host, port))
|
||||
await chan.connect()
|
||||
chan = await Channel.from_addr(addr)
|
||||
yield chan
|
||||
with trio.CancelScope(shield=True):
|
||||
await chan.aclose()
|
||||
|
|
|
@ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int:
|
|||
raise OSError(errno.errorcode[ffi.errno], 'close failed')
|
||||
|
||||
|
||||
class EFDReadCancelled(Exception):
|
||||
...
|
||||
|
||||
|
||||
class EventFD:
|
||||
'''
|
||||
Use a previously opened eventfd(2), meant to be used in
|
||||
|
@ -124,6 +128,7 @@ class EventFD:
|
|||
self._fd: int = fd
|
||||
self._omode: str = omode
|
||||
self._fobj = None
|
||||
self._cscope: trio.CancelScope | None = None
|
||||
|
||||
@property
|
||||
def fd(self) -> int | None:
|
||||
|
@ -133,17 +138,46 @@ class EventFD:
|
|||
return write_eventfd(self._fd, value)
|
||||
|
||||
async def read(self) -> int:
|
||||
return await trio.to_thread.run_sync(
|
||||
read_eventfd, self._fd,
|
||||
abandon_on_cancel=True
|
||||
)
|
||||
'''
|
||||
Async wrapper for `read_eventfd(self.fd)`
|
||||
|
||||
`trio.to_thread.run_sync` is used, need to use a `trio.CancelScope`
|
||||
in order to make it cancellable when `self.close()` is called.
|
||||
|
||||
'''
|
||||
self._cscope = trio.CancelScope()
|
||||
with self._cscope:
|
||||
return await trio.to_thread.run_sync(
|
||||
read_eventfd, self._fd,
|
||||
abandon_on_cancel=True
|
||||
)
|
||||
|
||||
if self._cscope.cancelled_caught:
|
||||
raise EFDReadCancelled
|
||||
|
||||
self._cscope = None
|
||||
|
||||
def read_direct(self) -> int:
|
||||
'''
|
||||
Direct call to `read_eventfd(self.fd)`, unless `eventfd` was
|
||||
opened with `EFD_NONBLOCK` its gonna block the thread.
|
||||
|
||||
'''
|
||||
return read_eventfd(self._fd)
|
||||
|
||||
def open(self):
|
||||
self._fobj = os.fdopen(self._fd, self._omode)
|
||||
|
||||
def close(self):
|
||||
if self._fobj:
|
||||
self._fobj.close()
|
||||
try:
|
||||
self._fobj.close()
|
||||
|
||||
except OSError:
|
||||
...
|
||||
|
||||
if self._cscope:
|
||||
self._cscope.cancel()
|
||||
|
||||
def __enter__(self):
|
||||
self.open()
|
||||
|
|
|
@ -18,7 +18,15 @@ IPC Reliable RingBuffer implementation
|
|||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
from contextlib import contextmanager as cm
|
||||
import struct
|
||||
from typing import (
|
||||
ContextManager,
|
||||
AsyncContextManager
|
||||
)
|
||||
from contextlib import (
|
||||
contextmanager as cm,
|
||||
asynccontextmanager as acm
|
||||
)
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
|
||||
import trio
|
||||
|
@ -28,25 +36,37 @@ from msgspec import (
|
|||
)
|
||||
|
||||
from ._linux import (
|
||||
EFD_NONBLOCK,
|
||||
open_eventfd,
|
||||
EFDReadCancelled,
|
||||
EventFD
|
||||
)
|
||||
from ._mp_bs import disable_mantracker
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
InternalError
|
||||
)
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
disable_mantracker()
|
||||
|
||||
_DEFAULT_RB_SIZE = 10 * 1024
|
||||
|
||||
|
||||
class RBToken(Struct, frozen=True):
|
||||
'''
|
||||
RingBuffer token contains necesary info to open the two
|
||||
RingBuffer token contains necesary info to open the three
|
||||
eventfds and the shared memory
|
||||
|
||||
'''
|
||||
shm_name: str
|
||||
write_eventfd: int
|
||||
wrap_eventfd: int
|
||||
|
||||
write_eventfd: int # used to signal writer ptr advance
|
||||
wrap_eventfd: int # used to signal reader ready after wrap around
|
||||
eof_eventfd: int # used to signal writer closed
|
||||
|
||||
buf_size: int
|
||||
|
||||
def as_msg(self):
|
||||
|
@ -59,62 +79,97 @@ class RBToken(Struct, frozen=True):
|
|||
|
||||
return RBToken(**msg)
|
||||
|
||||
@property
|
||||
def fds(self) -> tuple[int, int, int]:
|
||||
'''
|
||||
Useful for `pass_fds` params
|
||||
|
||||
'''
|
||||
return (
|
||||
self.write_eventfd,
|
||||
self.wrap_eventfd,
|
||||
self.eof_eventfd
|
||||
)
|
||||
|
||||
|
||||
@cm
|
||||
def open_ringbuf(
|
||||
shm_name: str,
|
||||
buf_size: int = 10 * 1024,
|
||||
write_efd_flags: int = 0,
|
||||
wrap_efd_flags: int = 0
|
||||
) -> RBToken:
|
||||
buf_size: int = _DEFAULT_RB_SIZE,
|
||||
) -> ContextManager[RBToken]:
|
||||
'''
|
||||
Handle resources for a ringbuf (shm, eventfd), yield `RBToken` to
|
||||
be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver`
|
||||
|
||||
'''
|
||||
shm = SharedMemory(
|
||||
name=shm_name,
|
||||
size=buf_size,
|
||||
create=True
|
||||
)
|
||||
try:
|
||||
token = RBToken(
|
||||
shm_name=shm_name,
|
||||
write_eventfd=open_eventfd(flags=write_efd_flags),
|
||||
wrap_eventfd=open_eventfd(flags=wrap_efd_flags),
|
||||
buf_size=buf_size
|
||||
)
|
||||
yield token
|
||||
with (
|
||||
EventFD(open_eventfd(), 'r') as write_event,
|
||||
EventFD(open_eventfd(), 'r') as wrap_event,
|
||||
EventFD(open_eventfd(), 'r') as eof_event,
|
||||
):
|
||||
token = RBToken(
|
||||
shm_name=shm_name,
|
||||
write_eventfd=write_event.fd,
|
||||
wrap_eventfd=wrap_event.fd,
|
||||
eof_eventfd=eof_event.fd,
|
||||
buf_size=buf_size
|
||||
)
|
||||
yield token
|
||||
|
||||
finally:
|
||||
shm.unlink()
|
||||
|
||||
|
||||
Buffer = bytes | bytearray | memoryview
|
||||
|
||||
|
||||
'''
|
||||
IPC Reliable Ring Buffer
|
||||
|
||||
`eventfd(2)` is used for wrap around sync, to signal writes to
|
||||
the reader and end of stream.
|
||||
|
||||
'''
|
||||
|
||||
|
||||
class RingBuffSender(trio.abc.SendStream):
|
||||
'''
|
||||
IPC Reliable Ring Buffer sender side implementation
|
||||
Ring Buffer sender side implementation
|
||||
|
||||
`eventfd(2)` is used for wrap around sync, and also to signal
|
||||
writes to the reader.
|
||||
Do not use directly! manage with `attach_to_ringbuf_sender`
|
||||
after having opened a ringbuf context with `open_ringbuf`.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
token: RBToken,
|
||||
start_ptr: int = 0,
|
||||
cleanup: bool = False
|
||||
):
|
||||
token = RBToken.from_msg(token)
|
||||
self._shm = SharedMemory(
|
||||
name=token.shm_name,
|
||||
size=token.buf_size,
|
||||
create=False
|
||||
)
|
||||
self._write_event = EventFD(token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
|
||||
self._ptr = start_ptr
|
||||
self._token = RBToken.from_msg(token)
|
||||
self._shm: SharedMemory | None = None
|
||||
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
||||
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
|
||||
self._ptr = 0
|
||||
|
||||
self._cleanup = cleanup
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
def name(self) -> str:
|
||||
if not self._shm:
|
||||
raise ValueError('shared memory not initialized yet!')
|
||||
return self._shm.name
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self._shm.size
|
||||
return self._token.buf_size
|
||||
|
||||
@property
|
||||
def ptr(self) -> int:
|
||||
|
@ -128,73 +183,97 @@ class RingBuffSender(trio.abc.SendStream):
|
|||
def wrap_fd(self) -> int:
|
||||
return self._wrap_event.fd
|
||||
|
||||
async def send_all(self, data: bytes | bytearray | memoryview):
|
||||
# while data is larger than the remaining buf
|
||||
target_ptr = self.ptr + len(data)
|
||||
while target_ptr > self.size:
|
||||
# write all bytes that fit
|
||||
remaining = self.size - self.ptr
|
||||
self._shm.buf[self.ptr:] = data[:remaining]
|
||||
# signal write and wait for reader wrap around
|
||||
self._write_event.write(remaining)
|
||||
await self._wrap_event.read()
|
||||
async def _wait_wrap(self):
|
||||
await self._wrap_event.read()
|
||||
|
||||
# wrap around and trim already written bytes
|
||||
self._ptr = 0
|
||||
data = data[remaining:]
|
||||
target_ptr = self._ptr + len(data)
|
||||
async def send_all(self, data: Buffer):
|
||||
async with self._send_lock:
|
||||
# while data is larger than the remaining buf
|
||||
target_ptr = self.ptr + len(data)
|
||||
while target_ptr > self.size:
|
||||
# write all bytes that fit
|
||||
remaining = self.size - self.ptr
|
||||
self._shm.buf[self.ptr:] = data[:remaining]
|
||||
# signal write and wait for reader wrap around
|
||||
self._write_event.write(remaining)
|
||||
await self._wait_wrap()
|
||||
|
||||
# remaining data fits on buffer
|
||||
self._shm.buf[self.ptr:target_ptr] = data
|
||||
self._write_event.write(len(data))
|
||||
self._ptr = target_ptr
|
||||
# wrap around and trim already written bytes
|
||||
self._ptr = 0
|
||||
data = data[remaining:]
|
||||
target_ptr = self._ptr + len(data)
|
||||
|
||||
# remaining data fits on buffer
|
||||
self._shm.buf[self.ptr:target_ptr] = data
|
||||
self._write_event.write(len(data))
|
||||
self._ptr = target_ptr
|
||||
|
||||
async def wait_send_all_might_not_block(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def aclose(self):
|
||||
self._write_event.close()
|
||||
self._wrap_event.close()
|
||||
self._shm.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
def open(self):
|
||||
self._shm = SharedMemory(
|
||||
name=self._token.shm_name,
|
||||
size=self._token.buf_size,
|
||||
create=False
|
||||
)
|
||||
self._write_event.open()
|
||||
self._wrap_event.open()
|
||||
self._eof_event.open()
|
||||
|
||||
def close(self):
|
||||
self._eof_event.write(
|
||||
self._ptr if self._ptr > 0 else self.size
|
||||
)
|
||||
|
||||
if self._cleanup:
|
||||
self._write_event.close()
|
||||
self._wrap_event.close()
|
||||
self._eof_event.close()
|
||||
self._shm.close()
|
||||
|
||||
async def aclose(self):
|
||||
async with self._send_lock:
|
||||
self.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.open()
|
||||
return self
|
||||
|
||||
|
||||
class RingBuffReceiver(trio.abc.ReceiveStream):
|
||||
'''
|
||||
IPC Reliable Ring Buffer receiver side implementation
|
||||
Ring Buffer receiver side implementation
|
||||
|
||||
`eventfd(2)` is used for wrap around sync, and also to signal
|
||||
writes to the reader.
|
||||
Do not use directly! manage with `attach_to_ringbuf_receiver`
|
||||
after having opened a ringbuf context with `open_ringbuf`.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
token: RBToken,
|
||||
start_ptr: int = 0,
|
||||
flags: int = 0
|
||||
cleanup: bool = True,
|
||||
):
|
||||
token = RBToken.from_msg(token)
|
||||
self._shm = SharedMemory(
|
||||
name=token.shm_name,
|
||||
size=token.buf_size,
|
||||
create=False
|
||||
)
|
||||
self._write_event = EventFD(token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
|
||||
self._ptr = start_ptr
|
||||
self._flags = flags
|
||||
self._token = RBToken.from_msg(token)
|
||||
self._shm: SharedMemory | None = None
|
||||
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
||||
self._eof_event = EventFD(self._token.eof_eventfd, 'r')
|
||||
self._ptr: int = 0
|
||||
self._write_ptr: int = 0
|
||||
self._end_ptr: int = -1
|
||||
|
||||
self._cleanup: bool = cleanup
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
def name(self) -> str:
|
||||
if not self._shm:
|
||||
raise ValueError('shared memory not initialized yet!')
|
||||
return self._shm.name
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self._shm.size
|
||||
return self._token.buf_size
|
||||
|
||||
@property
|
||||
def ptr(self) -> int:
|
||||
|
@ -208,46 +287,368 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
|
|||
def wrap_fd(self) -> int:
|
||||
return self._wrap_event.fd
|
||||
|
||||
async def receive_some(
|
||||
self,
|
||||
max_bytes: int | None = None,
|
||||
nb_timeout: float = 0.1
|
||||
) -> memoryview:
|
||||
# if non blocking eventfd enabled, do polling
|
||||
# until next write, this allows signal handling
|
||||
if self._flags | EFD_NONBLOCK:
|
||||
delta = None
|
||||
while delta is None:
|
||||
async def _eof_monitor_task(self):
|
||||
'''
|
||||
Long running EOF event monitor, automatically run in bg by
|
||||
`attach_to_ringbuf_receiver` context manager, if EOF event
|
||||
is set its value will be the end pointer (highest valid
|
||||
index to be read from buf, after setting the `self._end_ptr`
|
||||
we close the write event which should cancel any blocked
|
||||
`self._write_event.read()`s on it.
|
||||
|
||||
'''
|
||||
try:
|
||||
self._end_ptr = await self._eof_event.read()
|
||||
self._write_event.close()
|
||||
|
||||
except EFDReadCancelled:
|
||||
...
|
||||
|
||||
except trio.Cancelled:
|
||||
...
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
||||
'''
|
||||
Receive up to `max_bytes`, if no `max_bytes` is provided
|
||||
a reasonable default is used.
|
||||
|
||||
'''
|
||||
if max_bytes is None:
|
||||
max_bytes: int = _DEFAULT_RB_SIZE
|
||||
|
||||
if max_bytes < 1:
|
||||
raise ValueError("max_bytes must be >= 1")
|
||||
|
||||
# delta is remaining bytes we havent read
|
||||
delta = self._write_ptr - self._ptr
|
||||
if delta == 0:
|
||||
# we have read all we can, see if new data is available
|
||||
if self._end_ptr < 0:
|
||||
# if we havent been signaled about EOF yet
|
||||
try:
|
||||
delta = await self._write_event.read()
|
||||
self._write_ptr += delta
|
||||
|
||||
except OSError as e:
|
||||
if e.errno == 'EAGAIN':
|
||||
continue
|
||||
except EFDReadCancelled:
|
||||
# while waiting for new data `self._write_event` was closed
|
||||
# this means writer signaled EOF
|
||||
if self._end_ptr > 0:
|
||||
# final self._write_ptr modification and recalculate delta
|
||||
self._write_ptr = self._end_ptr
|
||||
delta = self._end_ptr - self._ptr
|
||||
|
||||
raise e
|
||||
else:
|
||||
# shouldnt happen cause self._eof_monitor_task always sets
|
||||
# self._end_ptr before closing self._write_event
|
||||
raise InternalError(
|
||||
'self._write_event.read cancelled but self._end_ptr is not set'
|
||||
)
|
||||
|
||||
else:
|
||||
delta = await self._write_event.read()
|
||||
else:
|
||||
# no more bytes to read and self._end_ptr set, EOF reached
|
||||
return b''
|
||||
|
||||
# dont overflow caller
|
||||
delta = min(delta, max_bytes)
|
||||
|
||||
target_ptr = self._ptr + delta
|
||||
|
||||
# fetch next segment and advance ptr
|
||||
next_ptr = self._ptr + delta
|
||||
segment = self._shm.buf[self._ptr:next_ptr]
|
||||
self._ptr = next_ptr
|
||||
segment = bytes(self._shm.buf[self._ptr:target_ptr])
|
||||
self._ptr = target_ptr
|
||||
|
||||
if self.ptr == self.size:
|
||||
if self._ptr == self.size:
|
||||
# reached the end, signal wrap around
|
||||
self._ptr = 0
|
||||
self._write_ptr = 0
|
||||
self._wrap_event.write(1)
|
||||
|
||||
return segment
|
||||
|
||||
async def aclose(self):
|
||||
self._write_event.close()
|
||||
self._wrap_event.close()
|
||||
self._shm.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
def open(self):
|
||||
self._shm = SharedMemory(
|
||||
name=self._token.shm_name,
|
||||
size=self._token.buf_size,
|
||||
create=False
|
||||
)
|
||||
self._write_event.open()
|
||||
self._wrap_event.open()
|
||||
self._eof_event.open()
|
||||
|
||||
def close(self):
|
||||
if self._cleanup:
|
||||
self._write_event.close()
|
||||
self._wrap_event.close()
|
||||
self._eof_event.close()
|
||||
self._shm.close()
|
||||
|
||||
async def aclose(self):
|
||||
self.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.open()
|
||||
return self
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_receiver(
|
||||
token: RBToken,
|
||||
cleanup: bool = True
|
||||
) -> AsyncContextManager[RingBuffReceiver]:
|
||||
'''
|
||||
Attach a RingBuffReceiver from a previously opened
|
||||
RBToken.
|
||||
|
||||
Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
|
||||
'''
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
RingBuffReceiver(
|
||||
token,
|
||||
cleanup=cleanup
|
||||
) as receiver
|
||||
):
|
||||
n.start_soon(receiver._eof_monitor_task)
|
||||
yield receiver
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_sender(
|
||||
token: RBToken,
|
||||
cleanup: bool = True
|
||||
) -> AsyncContextManager[RingBuffSender]:
|
||||
'''
|
||||
Attach a RingBuffSender from a previously opened
|
||||
RBToken.
|
||||
|
||||
'''
|
||||
async with RingBuffSender(
|
||||
token,
|
||||
cleanup=cleanup
|
||||
) as sender:
|
||||
yield sender
|
||||
|
||||
|
||||
@cm
|
||||
def open_ringbuf_pair(
|
||||
name: str,
|
||||
buf_size: int = _DEFAULT_RB_SIZE
|
||||
) -> ContextManager[tuple(RBToken, RBToken)]:
|
||||
'''
|
||||
Handle resources for a ringbuf pair to be used for
|
||||
bidirectional messaging.
|
||||
|
||||
'''
|
||||
with (
|
||||
open_ringbuf(
|
||||
name + '.pair0',
|
||||
buf_size=buf_size
|
||||
) as token_0,
|
||||
|
||||
open_ringbuf(
|
||||
name + '.pair1',
|
||||
buf_size=buf_size
|
||||
) as token_1
|
||||
):
|
||||
yield token_0, token_1
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_stream(
|
||||
token_in: RBToken,
|
||||
token_out: RBToken,
|
||||
cleanup_in: bool = True,
|
||||
cleanup_out: bool = True
|
||||
) -> AsyncContextManager[trio.StapledStream]:
|
||||
'''
|
||||
Attach a trio.StapledStream from a previously opened
|
||||
ringbuf pair.
|
||||
|
||||
'''
|
||||
async with (
|
||||
attach_to_ringbuf_receiver(
|
||||
token_in,
|
||||
cleanup=cleanup_in
|
||||
) as receiver,
|
||||
attach_to_ringbuf_sender(
|
||||
token_out,
|
||||
cleanup=cleanup_out
|
||||
) as sender,
|
||||
):
|
||||
yield trio.StapledStream(sender, receiver)
|
||||
|
||||
|
||||
|
||||
class RingBuffBytesSender(trio.abc.SendChannel[bytes]):
|
||||
'''
|
||||
In order to guarantee full messages are received, all bytes
|
||||
sent by `RingBuffBytesSender` are preceded with a 4 byte header
|
||||
which decodes into a uint32 indicating the actual size of the
|
||||
next payload.
|
||||
|
||||
Optional batch mode:
|
||||
|
||||
If `batch_size` > 1 messages wont get sent immediately but will be
|
||||
stored until `batch_size` messages are pending, then it will send
|
||||
them all at once.
|
||||
|
||||
`batch_size` can be changed dynamically but always call, `flush()`
|
||||
right before.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
sender: RingBuffSender,
|
||||
batch_size: int = 1
|
||||
):
|
||||
self._sender = sender
|
||||
self.batch_size = batch_size
|
||||
self._batch_msg_len = 0
|
||||
self._batch: bytes = b''
|
||||
|
||||
async def flush(self) -> None:
|
||||
await self._sender.send_all(self._batch)
|
||||
self._batch = b''
|
||||
self._batch_msg_len = 0
|
||||
|
||||
async def send(self, value: bytes) -> None:
|
||||
msg: bytes = struct.pack("<I", len(value)) + value
|
||||
if self.batch_size == 1:
|
||||
await self._sender.send_all(msg)
|
||||
return
|
||||
|
||||
self._batch += msg
|
||||
self._batch_msg_len += 1
|
||||
if self._batch_msg_len == self.batch_size:
|
||||
await self.flush()
|
||||
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._sender.aclose()
|
||||
|
||||
|
||||
class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]):
|
||||
'''
|
||||
See `RingBuffBytesSender` docstring.
|
||||
|
||||
A `tricycle.BufferedReceiveStream` is used for the
|
||||
`receive_exactly` API.
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
receiver: RingBuffReceiver
|
||||
):
|
||||
self._receiver = receiver
|
||||
|
||||
async def _receive_exactly(self, num_bytes: int) -> bytes:
|
||||
'''
|
||||
Fetch bytes from receiver until we read exactly `num_bytes`
|
||||
or end of stream is signaled.
|
||||
|
||||
'''
|
||||
payload = b''
|
||||
while len(payload) < num_bytes:
|
||||
remaining = num_bytes - len(payload)
|
||||
|
||||
new_bytes = await self._receiver.receive_some(
|
||||
max_bytes=remaining
|
||||
)
|
||||
|
||||
if new_bytes == b'':
|
||||
raise trio.EndOfChannel
|
||||
|
||||
payload += new_bytes
|
||||
|
||||
return payload
|
||||
|
||||
async def receive(self) -> bytes:
|
||||
header: bytes = await self._receive_exactly(4)
|
||||
size: int
|
||||
size, = struct.unpack("<I", header)
|
||||
if size == 0:
|
||||
raise trio.EndOfChannel
|
||||
return await self._receive_exactly(size)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._receiver.aclose()
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_rchannel(
|
||||
token: RBToken,
|
||||
cleanup: bool = True
|
||||
) -> AsyncContextManager[RingBuffBytesReceiver]:
|
||||
'''
|
||||
Attach a RingBuffBytesReceiver from a previously opened
|
||||
RBToken.
|
||||
'''
|
||||
async with attach_to_ringbuf_receiver(
|
||||
token, cleanup=cleanup
|
||||
) as receiver:
|
||||
yield RingBuffBytesReceiver(receiver)
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_schannel(
|
||||
token: RBToken,
|
||||
cleanup: bool = True,
|
||||
batch_size: int = 1,
|
||||
) -> AsyncContextManager[RingBuffBytesSender]:
|
||||
'''
|
||||
Attach a RingBuffBytesSender from a previously opened
|
||||
RBToken.
|
||||
'''
|
||||
async with attach_to_ringbuf_sender(
|
||||
token, cleanup=cleanup
|
||||
) as sender:
|
||||
yield RingBuffBytesSender(sender, batch_size=batch_size)
|
||||
|
||||
|
||||
class RingBuffChannel(trio.abc.Channel[bytes]):
|
||||
'''
|
||||
Combine `RingBuffBytesSender` and `RingBuffBytesReceiver`
|
||||
in order to expose the bidirectional `trio.abc.Channel` API.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
sender: RingBuffBytesSender,
|
||||
receiver: RingBuffBytesReceiver
|
||||
):
|
||||
self._sender = sender
|
||||
self._receiver = receiver
|
||||
|
||||
async def send(self, value: bytes):
|
||||
await self._sender.send(value)
|
||||
|
||||
async def receive(self) -> bytes:
|
||||
return await self._receiver.receive()
|
||||
|
||||
async def aclose(self):
|
||||
await self._receiver.aclose()
|
||||
await self._sender.aclose()
|
||||
|
||||
|
||||
@acm
|
||||
async def attach_to_ringbuf_channel(
|
||||
token_in: RBToken,
|
||||
token_out: RBToken,
|
||||
cleanup_in: bool = True,
|
||||
cleanup_out: bool = True
|
||||
) -> AsyncContextManager[RingBuffChannel]:
|
||||
'''
|
||||
Attach to an already opened ringbuf pair and return
|
||||
a `RingBuffChannel`.
|
||||
|
||||
'''
|
||||
async with (
|
||||
attach_to_ringbuf_rchannel(
|
||||
token_in,
|
||||
cleanup=cleanup_in
|
||||
) as receiver,
|
||||
attach_to_ringbuf_schannel(
|
||||
token_out,
|
||||
cleanup=cleanup_out
|
||||
) as sender,
|
||||
):
|
||||
yield RingBuffChannel(sender, receiver)
|
||||
|
|
|
@ -18,389 +18,88 @@ TCP implementation of tractor.ipc._transport.MsgTransport protocol
|
|||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
)
|
||||
import struct
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Type,
|
||||
)
|
||||
|
||||
import msgspec
|
||||
from tricycle import BufferedReceiveStream
|
||||
import trio
|
||||
|
||||
from tractor.msg import MsgCodec
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
MsgTypeError,
|
||||
TransportClosed,
|
||||
_mk_send_mte,
|
||||
_mk_recv_mte,
|
||||
)
|
||||
from tractor.msg import (
|
||||
_ctxvar_MsgCodec,
|
||||
# _codec, XXX see `self._codec` sanity/debug checks
|
||||
MsgCodec,
|
||||
types as msgtypes,
|
||||
pretty_struct,
|
||||
)
|
||||
from tractor.ipc import MsgTransport
|
||||
from tractor._addr import TCPAddress
|
||||
from tractor.ipc._transport import MsgpackTransport
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def get_stream_addrs(
|
||||
stream: trio.SocketStream
|
||||
) -> tuple[
|
||||
tuple[str, int], # local
|
||||
tuple[str, int], # remote
|
||||
]:
|
||||
'''
|
||||
Return the `trio` streaming transport prot's socket-addrs for
|
||||
both the local and remote sides as a pair.
|
||||
|
||||
'''
|
||||
# rn, should both be IP sockets
|
||||
lsockname = stream.socket.getsockname()
|
||||
rsockname = stream.socket.getpeername()
|
||||
return (
|
||||
tuple(lsockname[:2]),
|
||||
tuple(rsockname[:2]),
|
||||
)
|
||||
|
||||
|
||||
# TODO: typing oddity.. not sure why we have to inherit here, but it
|
||||
# seems to be an issue with `get_msg_transport()` returning
|
||||
# a `Type[Protocol]`; probably should make a `mypy` issue?
|
||||
class MsgpackTCPStream(MsgTransport):
|
||||
class MsgpackTCPStream(MsgpackTransport):
|
||||
'''
|
||||
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||
using the ``msgspec`` codec lib.
|
||||
|
||||
'''
|
||||
address_type = TCPAddress
|
||||
layer_key: int = 4
|
||||
name_key: str = 'tcp'
|
||||
|
||||
# TODO: better naming for this?
|
||||
# -[ ] check how libp2p does naming for such things?
|
||||
codec_key: str = 'msgpack'
|
||||
# def __init__(
|
||||
# self,
|
||||
# stream: trio.SocketStream,
|
||||
# prefix_size: int = 4,
|
||||
# codec: CodecType = None,
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: trio.SocketStream,
|
||||
prefix_size: int = 4,
|
||||
# ) -> None:
|
||||
# super().__init__(
|
||||
# stream,
|
||||
# prefix_size=prefix_size,
|
||||
# codec=codec
|
||||
# )
|
||||
|
||||
# XXX optionally provided codec pair for `msgspec`:
|
||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
||||
#
|
||||
# TODO: define this as a `Codec` struct which can be
|
||||
# overriden dynamically by the application/runtime?
|
||||
codec: tuple[
|
||||
Callable[[Any], Any]|None, # coder
|
||||
Callable[[type, Any], Any]|None, # decoder
|
||||
]|None = None,
|
||||
@property
|
||||
def maddr(self) -> str:
|
||||
host, port = self.raddr.unwrap()
|
||||
return (
|
||||
f'/ipv4/{host}'
|
||||
f'/{self.address_type.name_key}/{port}'
|
||||
# f'/{self.chan.uid[0]}'
|
||||
# f'/{self.cid}'
|
||||
|
||||
) -> None:
|
||||
|
||||
self.stream = stream
|
||||
assert self.stream.socket
|
||||
|
||||
# should both be IP sockets
|
||||
self._laddr, self._raddr = get_stream_addrs(stream)
|
||||
|
||||
# create read loop instance
|
||||
self._aiter_pkts = self._iter_packets()
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
# public i guess?
|
||||
self.drained: list[dict] = []
|
||||
|
||||
self.recv_stream = BufferedReceiveStream(
|
||||
transport_stream=stream
|
||||
# f'/cid={cid_head}..{cid_tail}'
|
||||
# TODO: ? not use this ^ right ?
|
||||
)
|
||||
self.prefix_size = prefix_size
|
||||
|
||||
# allow for custom IPC msg interchange format
|
||||
# dynamic override Bo
|
||||
self._task = trio.lowlevel.current_task()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# self._codec: MsgCodec = (
|
||||
# codec
|
||||
# or
|
||||
# _codec._ctxvar_MsgCodec.get()
|
||||
# )
|
||||
|
||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||
'''
|
||||
Yield `bytes`-blob decoded packets from the underlying TCP
|
||||
stream using the current task's `MsgCodec`.
|
||||
|
||||
This is a streaming routine implemented as an async generator
|
||||
func (which was the original design, but could be changed?)
|
||||
and is allocated by a `.__call__()` inside `.__init__()` where
|
||||
it is assigned to the `._aiter_pkts` attr.
|
||||
|
||||
'''
|
||||
decodes_failed: int = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
header: bytes = await self.recv_stream.receive_exactly(4)
|
||||
except (
|
||||
ValueError,
|
||||
ConnectionResetError,
|
||||
|
||||
# not sure entirely why we need this but without it we
|
||||
# seem to be getting racy failures here on
|
||||
# arbiter/registry name subs..
|
||||
trio.BrokenResourceError,
|
||||
|
||||
) as trans_err:
|
||||
|
||||
loglevel = 'transport'
|
||||
match trans_err:
|
||||
# case (
|
||||
# ConnectionResetError()
|
||||
# ):
|
||||
# loglevel = 'transport'
|
||||
|
||||
# peer actor (graceful??) TCP EOF but `tricycle`
|
||||
# seems to raise a 0-bytes-read?
|
||||
case ValueError() if (
|
||||
'unclean EOF' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# peer actor (task) prolly shutdown quickly due
|
||||
# to cancellation
|
||||
case trio.BrokenResourceError() if (
|
||||
'Connection reset by peer' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# unless the disconnect condition falls under "a
|
||||
# normal operation breakage" we usualy console warn
|
||||
# about it.
|
||||
case _:
|
||||
loglevel: str = 'warning'
|
||||
|
||||
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already closed by peer\n'
|
||||
f'x)> {type(trans_err)}\n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel=loglevel,
|
||||
) from trans_err
|
||||
|
||||
# XXX definitely can happen if transport is closed
|
||||
# manually by another `trio.lowlevel.Task` in the
|
||||
# same actor; we use this in some simulated fault
|
||||
# testing for ex, but generally should never happen
|
||||
# under normal operation!
|
||||
#
|
||||
# NOTE: as such we always re-raise this error from the
|
||||
# RPC msg loop!
|
||||
except trio.ClosedResourceError as closure_err:
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already manually closed locally?\n'
|
||||
f'x)> {type(closure_err)} \n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel='error',
|
||||
raise_on_report=(
|
||||
closure_err.args[0] == 'another task closed this fd'
|
||||
or
|
||||
closure_err.args[0] in ['another task closed this fd']
|
||||
),
|
||||
) from closure_err
|
||||
|
||||
# graceful TCP EOF disconnect
|
||||
if header == b'':
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already gracefully closed\n'
|
||||
f')>\n'
|
||||
f'|_{self}\n'
|
||||
),
|
||||
loglevel='transport',
|
||||
# cause=??? # handy or no?
|
||||
)
|
||||
|
||||
size: int
|
||||
size, = struct.unpack("<I", header)
|
||||
|
||||
log.transport(f'received header {size}') # type: ignore
|
||||
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
|
||||
|
||||
log.transport(f"received {msg_bytes}") # type: ignore
|
||||
try:
|
||||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# if self._codec.pld_spec != codec.pld_spec:
|
||||
# assert (
|
||||
# task := trio.lowlevel.current_task()
|
||||
# ) is not self._task
|
||||
# self._task = task
|
||||
# self._codec = codec
|
||||
# log.runtime(
|
||||
# f'Using new codec in {self}.recv()\n'
|
||||
# f'codec: {self._codec}\n\n'
|
||||
# f'msg_bytes: {msg_bytes}\n'
|
||||
# )
|
||||
yield codec.decode(msg_bytes)
|
||||
|
||||
# XXX NOTE: since the below error derives from
|
||||
# `DecodeError` we need to catch is specially
|
||||
# and always raise such that spec violations
|
||||
# are never allowed to be caught silently!
|
||||
except msgspec.ValidationError as verr:
|
||||
msgtyperr: MsgTypeError = _mk_recv_mte(
|
||||
msg=msg_bytes,
|
||||
codec=codec,
|
||||
src_validation_error=verr,
|
||||
)
|
||||
# XXX deliver up to `Channel.recv()` where
|
||||
# a re-raise and `Error`-pack can inject the far
|
||||
# end actor `.uid`.
|
||||
yield msgtyperr
|
||||
|
||||
except (
|
||||
msgspec.DecodeError,
|
||||
UnicodeDecodeError,
|
||||
):
|
||||
if decodes_failed < 4:
|
||||
# ignore decoding errors for now and assume they have to
|
||||
# do with a channel drop - hope that receiving from the
|
||||
# channel will raise an expected error and bubble up.
|
||||
try:
|
||||
msg_str: str|bytes = msg_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
msg_str = msg_bytes
|
||||
|
||||
log.exception(
|
||||
'Failed to decode msg?\n'
|
||||
f'{codec}\n\n'
|
||||
'Rxed bytes from wire:\n\n'
|
||||
f'{msg_str!r}\n'
|
||||
)
|
||||
decodes_failed += 1
|
||||
else:
|
||||
raise
|
||||
|
||||
async def send(
|
||||
self,
|
||||
msg: msgtypes.MsgType,
|
||||
|
||||
strict_types: bool = True,
|
||||
hide_tb: bool = False,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
Send a msgpack encoded py-object-blob-as-msg over TCP.
|
||||
|
||||
If `strict_types == True` then a `MsgTypeError` will be raised on any
|
||||
invalid msg type
|
||||
|
||||
'''
|
||||
__tracebackhide__: bool = hide_tb
|
||||
|
||||
# XXX see `trio._sync.AsyncContextManagerMixin` for details
|
||||
# on the `.acquire()`/`.release()` sequencing..
|
||||
async with self._send_lock:
|
||||
|
||||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# if self._codec.pld_spec != codec.pld_spec:
|
||||
# self._codec = codec
|
||||
# log.runtime(
|
||||
# f'Using new codec in {self}.send()\n'
|
||||
# f'codec: {self._codec}\n\n'
|
||||
# f'msg: {msg}\n'
|
||||
# )
|
||||
|
||||
if type(msg) not in msgtypes.__msg_types__:
|
||||
if strict_types:
|
||||
raise _mk_send_mte(
|
||||
msg,
|
||||
codec=codec,
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
'Sending non-`Msg`-spec msg?\n\n'
|
||||
f'{msg}\n'
|
||||
)
|
||||
|
||||
try:
|
||||
bytes_data: bytes = codec.encode(msg)
|
||||
except TypeError as _err:
|
||||
typerr = _err
|
||||
msgtyperr: MsgTypeError = _mk_send_mte(
|
||||
msg,
|
||||
codec=codec,
|
||||
message=(
|
||||
f'IPC-msg-spec violation in\n\n'
|
||||
f'{pretty_struct.Struct.pformat(msg)}'
|
||||
),
|
||||
src_type_error=typerr,
|
||||
)
|
||||
raise msgtyperr from typerr
|
||||
|
||||
# supposedly the fastest says,
|
||||
# https://stackoverflow.com/a/54027962
|
||||
size: bytes = struct.pack("<I", len(bytes_data))
|
||||
return await self.stream.send_all(size + bytes_data)
|
||||
|
||||
# ?TODO? does it help ever to dynamically show this
|
||||
# frame?
|
||||
# try:
|
||||
# <the-above_code>
|
||||
# except BaseException as _err:
|
||||
# err = _err
|
||||
# if not isinstance(err, MsgTypeError):
|
||||
# __tracebackhide__: bool = False
|
||||
# raise
|
||||
|
||||
@property
|
||||
def laddr(self) -> tuple[str, int]:
|
||||
return self._laddr
|
||||
|
||||
@property
|
||||
def raddr(self) -> tuple[str, int]:
|
||||
return self._raddr
|
||||
|
||||
async def recv(self) -> Any:
|
||||
return await self._aiter_pkts.asend(None)
|
||||
|
||||
async def drain(self) -> AsyncIterator[dict]:
|
||||
'''
|
||||
Drain the stream's remaining messages sent from
|
||||
the far end until the connection is closed by
|
||||
the peer.
|
||||
|
||||
'''
|
||||
try:
|
||||
async for msg in self._iter_packets():
|
||||
self.drained.append(msg)
|
||||
except TransportClosed:
|
||||
for msg in self.drained:
|
||||
yield msg
|
||||
|
||||
def __aiter__(self):
|
||||
return self._aiter_pkts
|
||||
|
||||
def connected(self) -> bool:
|
||||
return self.stream.socket.fileno() != -1
|
||||
|
||||
@classmethod
|
||||
async def connect_to(
|
||||
cls,
|
||||
destaddr: TCPAddress,
|
||||
prefix_size: int = 4,
|
||||
codec: MsgCodec|None = None,
|
||||
**kwargs
|
||||
) -> MsgpackTCPStream:
|
||||
stream = await trio.open_tcp_stream(
|
||||
*destaddr.unwrap(),
|
||||
**kwargs
|
||||
)
|
||||
return MsgpackTCPStream(
|
||||
stream,
|
||||
prefix_size=prefix_size,
|
||||
codec=codec
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_stream_addrs(
|
||||
cls,
|
||||
stream: trio.SocketStream
|
||||
) -> tuple[
|
||||
tuple[str, int],
|
||||
tuple[str, int]
|
||||
]:
|
||||
lsockname = stream.socket.getsockname()
|
||||
rsockname = stream.socket.getpeername()
|
||||
return (
|
||||
TCPAddress.from_addr(tuple(lsockname[:2])),
|
||||
TCPAddress.from_addr(tuple(rsockname[:2])),
|
||||
)
|
||||
|
|
|
@ -18,13 +18,45 @@ typing.Protocol based generic msg API, implement this class to add backends for
|
|||
tractor.ipc.Channel
|
||||
|
||||
'''
|
||||
import trio
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
runtime_checkable,
|
||||
Type,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
ClassVar
|
||||
)
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
)
|
||||
import struct
|
||||
|
||||
import trio
|
||||
import msgspec
|
||||
from tricycle import BufferedReceiveStream
|
||||
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
MsgTypeError,
|
||||
TransportClosed,
|
||||
_mk_send_mte,
|
||||
_mk_recv_mte,
|
||||
)
|
||||
from tractor.msg import (
|
||||
_ctxvar_MsgCodec,
|
||||
# _codec, XXX see `self._codec` sanity/debug checks
|
||||
MsgCodec,
|
||||
types as msgtypes,
|
||||
pretty_struct,
|
||||
)
|
||||
from tractor._addr import Address
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
# (codec, transport)
|
||||
MsgTransportKey = tuple[str, str]
|
||||
|
||||
|
||||
# from tractor.msg.types import MsgType
|
||||
|
@ -41,11 +73,11 @@ class MsgTransport(Protocol[MsgType]):
|
|||
# eventual msg definition/types?
|
||||
# - https://docs.python.org/3/library/typing.html#typing.Protocol
|
||||
|
||||
stream: trio.SocketStream
|
||||
stream: trio.abc.Stream
|
||||
drained: list[MsgType]
|
||||
|
||||
def __init__(self, stream: trio.SocketStream) -> None:
|
||||
...
|
||||
address_type: ClassVar[Type[Address]]
|
||||
codec_key: ClassVar[str]
|
||||
|
||||
# XXX: should this instead be called `.sendall()`?
|
||||
async def send(self, msg: MsgType) -> None:
|
||||
|
@ -65,10 +97,354 @@ class MsgTransport(Protocol[MsgType]):
|
|||
def drain(self) -> AsyncIterator[dict]:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def key(cls) -> MsgTransportKey:
|
||||
return cls.codec_key, cls.address_type.name_key
|
||||
|
||||
@property
|
||||
def laddr(self) -> tuple[str, int]:
|
||||
def laddr(self) -> Address:
|
||||
...
|
||||
|
||||
@property
|
||||
def raddr(self) -> tuple[str, int]:
|
||||
def raddr(self) -> Address:
|
||||
...
|
||||
|
||||
@property
|
||||
def maddr(self) -> str:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
async def connect_to(
|
||||
cls,
|
||||
addr: Address,
|
||||
**kwargs
|
||||
) -> MsgTransport:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_stream_addrs(
|
||||
cls,
|
||||
stream: trio.abc.Stream
|
||||
) -> tuple[
|
||||
Address, # local
|
||||
Address # remote
|
||||
]:
|
||||
'''
|
||||
Return the `trio` streaming transport prot's addrs for both
|
||||
the local and remote sides as a pair.
|
||||
|
||||
'''
|
||||
...
|
||||
|
||||
|
||||
class MsgpackTransport(MsgTransport):
|
||||
|
||||
# TODO: better naming for this?
|
||||
# -[ ] check how libp2p does naming for such things?
|
||||
codec_key: str = 'msgpack'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: trio.abc.Stream,
|
||||
prefix_size: int = 4,
|
||||
|
||||
# XXX optionally provided codec pair for `msgspec`:
|
||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
||||
#
|
||||
# TODO: define this as a `Codec` struct which can be
|
||||
# overriden dynamically by the application/runtime?
|
||||
codec: MsgCodec = None,
|
||||
|
||||
) -> None:
|
||||
self.stream = stream
|
||||
self._laddr, self._raddr = self.get_stream_addrs(stream)
|
||||
|
||||
# create read loop instance
|
||||
self._aiter_pkts = self._iter_packets()
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
# public i guess?
|
||||
self.drained: list[dict] = []
|
||||
|
||||
self.recv_stream = BufferedReceiveStream(
|
||||
transport_stream=stream
|
||||
)
|
||||
self.prefix_size = prefix_size
|
||||
|
||||
# allow for custom IPC msg interchange format
|
||||
# dynamic override Bo
|
||||
self._task = trio.lowlevel.current_task()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# self._codec: MsgCodec = (
|
||||
# codec
|
||||
# or
|
||||
# _codec._ctxvar_MsgCodec.get()
|
||||
# )
|
||||
|
||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||
'''
|
||||
Yield `bytes`-blob decoded packets from the underlying TCP
|
||||
stream using the current task's `MsgCodec`.
|
||||
|
||||
This is a streaming routine implemented as an async generator
|
||||
func (which was the original design, but could be changed?)
|
||||
and is allocated by a `.__call__()` inside `.__init__()` where
|
||||
it is assigned to the `._aiter_pkts` attr.
|
||||
|
||||
'''
|
||||
decodes_failed: int = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
header: bytes = await self.recv_stream.receive_exactly(4)
|
||||
except (
|
||||
ValueError,
|
||||
ConnectionResetError,
|
||||
|
||||
# not sure entirely why we need this but without it we
|
||||
# seem to be getting racy failures here on
|
||||
# arbiter/registry name subs..
|
||||
trio.BrokenResourceError,
|
||||
|
||||
) as trans_err:
|
||||
|
||||
loglevel = 'transport'
|
||||
match trans_err:
|
||||
# case (
|
||||
# ConnectionResetError()
|
||||
# ):
|
||||
# loglevel = 'transport'
|
||||
|
||||
# peer actor (graceful??) TCP EOF but `tricycle`
|
||||
# seems to raise a 0-bytes-read?
|
||||
case ValueError() if (
|
||||
'unclean EOF' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# peer actor (task) prolly shutdown quickly due
|
||||
# to cancellation
|
||||
case trio.BrokenResourceError() if (
|
||||
'Connection reset by peer' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# unless the disconnect condition falls under "a
|
||||
# normal operation breakage" we usualy console warn
|
||||
# about it.
|
||||
case _:
|
||||
loglevel: str = 'warning'
|
||||
|
||||
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already closed by peer\n'
|
||||
f'x)> {type(trans_err)}\n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel=loglevel,
|
||||
) from trans_err
|
||||
|
||||
# XXX definitely can happen if transport is closed
|
||||
# manually by another `trio.lowlevel.Task` in the
|
||||
# same actor; we use this in some simulated fault
|
||||
# testing for ex, but generally should never happen
|
||||
# under normal operation!
|
||||
#
|
||||
# NOTE: as such we always re-raise this error from the
|
||||
# RPC msg loop!
|
||||
except trio.ClosedResourceError as closure_err:
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already manually closed locally?\n'
|
||||
f'x)> {type(closure_err)} \n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel='error',
|
||||
raise_on_report=(
|
||||
closure_err.args[0] == 'another task closed this fd'
|
||||
or
|
||||
closure_err.args[0] in ['another task closed this fd']
|
||||
),
|
||||
) from closure_err
|
||||
|
||||
# graceful TCP EOF disconnect
|
||||
if header == b'':
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already gracefully closed\n'
|
||||
f')>\n'
|
||||
f'|_{self}\n'
|
||||
),
|
||||
loglevel='transport',
|
||||
# cause=??? # handy or no?
|
||||
)
|
||||
|
||||
size: int
|
||||
size, = struct.unpack("<I", header)
|
||||
|
||||
log.transport(f'received header {size}') # type: ignore
|
||||
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
|
||||
|
||||
log.transport(f"received {msg_bytes}") # type: ignore
|
||||
try:
|
||||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# if self._codec.pld_spec != codec.pld_spec:
|
||||
# assert (
|
||||
# task := trio.lowlevel.current_task()
|
||||
# ) is not self._task
|
||||
# self._task = task
|
||||
# self._codec = codec
|
||||
# log.runtime(
|
||||
# f'Using new codec in {self}.recv()\n'
|
||||
# f'codec: {self._codec}\n\n'
|
||||
# f'msg_bytes: {msg_bytes}\n'
|
||||
# )
|
||||
yield codec.decode(msg_bytes)
|
||||
|
||||
# XXX NOTE: since the below error derives from
|
||||
# `DecodeError` we need to catch is specially
|
||||
# and always raise such that spec violations
|
||||
# are never allowed to be caught silently!
|
||||
except msgspec.ValidationError as verr:
|
||||
msgtyperr: MsgTypeError = _mk_recv_mte(
|
||||
msg=msg_bytes,
|
||||
codec=codec,
|
||||
src_validation_error=verr,
|
||||
)
|
||||
# XXX deliver up to `Channel.recv()` where
|
||||
# a re-raise and `Error`-pack can inject the far
|
||||
# end actor `.uid`.
|
||||
yield msgtyperr
|
||||
|
||||
except (
|
||||
msgspec.DecodeError,
|
||||
UnicodeDecodeError,
|
||||
):
|
||||
if decodes_failed < 4:
|
||||
# ignore decoding errors for now and assume they have to
|
||||
# do with a channel drop - hope that receiving from the
|
||||
# channel will raise an expected error and bubble up.
|
||||
try:
|
||||
msg_str: str|bytes = msg_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
msg_str = msg_bytes
|
||||
|
||||
log.exception(
|
||||
'Failed to decode msg?\n'
|
||||
f'{codec}\n\n'
|
||||
'Rxed bytes from wire:\n\n'
|
||||
f'{msg_str!r}\n'
|
||||
)
|
||||
decodes_failed += 1
|
||||
else:
|
||||
raise
|
||||
|
||||
async def send(
|
||||
self,
|
||||
msg: msgtypes.MsgType,
|
||||
|
||||
strict_types: bool = True,
|
||||
hide_tb: bool = False,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
Send a msgpack encoded py-object-blob-as-msg over TCP.
|
||||
|
||||
If `strict_types == True` then a `MsgTypeError` will be raised on any
|
||||
invalid msg type
|
||||
|
||||
'''
|
||||
__tracebackhide__: bool = hide_tb
|
||||
|
||||
# XXX see `trio._sync.AsyncContextManagerMixin` for details
|
||||
# on the `.acquire()`/`.release()` sequencing..
|
||||
async with self._send_lock:
|
||||
|
||||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# if self._codec.pld_spec != codec.pld_spec:
|
||||
# self._codec = codec
|
||||
# log.runtime(
|
||||
# f'Using new codec in {self}.send()\n'
|
||||
# f'codec: {self._codec}\n\n'
|
||||
# f'msg: {msg}\n'
|
||||
# )
|
||||
|
||||
if type(msg) not in msgtypes.__msg_types__:
|
||||
if strict_types:
|
||||
raise _mk_send_mte(
|
||||
msg,
|
||||
codec=codec,
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
'Sending non-`Msg`-spec msg?\n\n'
|
||||
f'{msg}\n'
|
||||
)
|
||||
|
||||
try:
|
||||
bytes_data: bytes = codec.encode(msg)
|
||||
except TypeError as _err:
|
||||
typerr = _err
|
||||
msgtyperr: MsgTypeError = _mk_send_mte(
|
||||
msg,
|
||||
codec=codec,
|
||||
message=(
|
||||
f'IPC-msg-spec violation in\n\n'
|
||||
f'{pretty_struct.Struct.pformat(msg)}'
|
||||
),
|
||||
src_type_error=typerr,
|
||||
)
|
||||
raise msgtyperr from typerr
|
||||
|
||||
# supposedly the fastest says,
|
||||
# https://stackoverflow.com/a/54027962
|
||||
size: bytes = struct.pack("<I", len(bytes_data))
|
||||
return await self.stream.send_all(size + bytes_data)
|
||||
|
||||
# ?TODO? does it help ever to dynamically show this
|
||||
# frame?
|
||||
# try:
|
||||
# <the-above_code>
|
||||
# except BaseException as _err:
|
||||
# err = _err
|
||||
# if not isinstance(err, MsgTypeError):
|
||||
# __tracebackhide__: bool = False
|
||||
# raise
|
||||
|
||||
async def recv(self) -> msgtypes.MsgType:
|
||||
return await self._aiter_pkts.asend(None)
|
||||
|
||||
async def drain(self) -> AsyncIterator[dict]:
|
||||
'''
|
||||
Drain the stream's remaining messages sent from
|
||||
the far end until the connection is closed by
|
||||
the peer.
|
||||
|
||||
'''
|
||||
try:
|
||||
async for msg in self._iter_packets():
|
||||
self.drained.append(msg)
|
||||
except TransportClosed:
|
||||
for msg in self.drained:
|
||||
yield msg
|
||||
|
||||
def __aiter__(self):
|
||||
return self._aiter_pkts
|
||||
|
||||
@property
|
||||
def laddr(self) -> Address:
|
||||
return self._laddr
|
||||
|
||||
@property
|
||||
def raddr(self) -> Address:
|
||||
return self._raddr
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# tractor: structured concurrent "actors".
|
||||
# Copyright 2018-eternity Tyler Goodlet.
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Type
|
||||
|
||||
import trio
|
||||
import socket
|
||||
|
||||
from tractor._addr import Address
|
||||
from tractor.ipc._transport import (
|
||||
MsgTransportKey,
|
||||
MsgTransport
|
||||
)
|
||||
from tractor.ipc._tcp import MsgpackTCPStream
|
||||
from tractor.ipc._uds import MsgpackUDSStream
|
||||
|
||||
|
||||
# manually updated list of all supported msg transport types
|
||||
_msg_transports = [
|
||||
MsgpackTCPStream,
|
||||
MsgpackUDSStream
|
||||
]
|
||||
|
||||
|
||||
# convert a MsgTransportKey to the corresponding transport type
|
||||
_key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = {
|
||||
cls.key(): cls
|
||||
for cls in _msg_transports
|
||||
}
|
||||
|
||||
# convert an Address wrapper to its corresponding transport type
|
||||
_addr_to_transport: dict[Type[Address], Type[MsgTransport]] = {
|
||||
cls.address_type: cls
|
||||
for cls in _msg_transports
|
||||
}
|
||||
|
||||
|
||||
def transport_from_addr(
|
||||
addr: Address,
|
||||
codec_key: str = 'msgpack',
|
||||
) -> Type[MsgTransport]:
|
||||
'''
|
||||
Given a destination address and a desired codec, find the
|
||||
corresponding `MsgTransport` type.
|
||||
|
||||
'''
|
||||
try:
|
||||
return _addr_to_transport[type(addr)]
|
||||
|
||||
except KeyError:
|
||||
raise NotImplementedError(
|
||||
f'No known transport for address {repr(addr)}'
|
||||
)
|
||||
|
||||
|
||||
def transport_from_stream(
|
||||
stream: trio.abc.Stream,
|
||||
codec_key: str = 'msgpack'
|
||||
) -> Type[MsgTransport]:
|
||||
'''
|
||||
Given an arbitrary `trio.abc.Stream` and a desired codec,
|
||||
find the corresponding `MsgTransport` type.
|
||||
|
||||
'''
|
||||
transport = None
|
||||
if isinstance(stream, trio.SocketStream):
|
||||
sock = stream.socket
|
||||
match sock.family:
|
||||
case socket.AF_INET | socket.AF_INET6:
|
||||
transport = 'tcp'
|
||||
|
||||
case socket.AF_UNIX:
|
||||
transport = 'uds'
|
||||
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f'Unsupported socket family: {sock.family}'
|
||||
)
|
||||
|
||||
if not transport:
|
||||
raise NotImplementedError(
|
||||
f'Could not figure out transport type for stream type {type(stream)}'
|
||||
)
|
||||
|
||||
key = (codec_key, transport)
|
||||
|
||||
return _key_to_transport[key]
|
|
@ -0,0 +1,97 @@
|
|||
# tractor: structured concurrent "actors".
|
||||
# Copyright 2018-eternity Tyler Goodlet.
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
'''
|
||||
Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protocol
|
||||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
|
||||
import trio
|
||||
|
||||
from tractor.msg import MsgCodec
|
||||
from tractor.log import get_logger
|
||||
from tractor._addr import UDSAddress
|
||||
from tractor.ipc._transport import MsgpackTransport
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class MsgpackUDSStream(MsgpackTransport):
|
||||
'''
|
||||
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||
using the ``msgspec`` codec lib.
|
||||
|
||||
'''
|
||||
address_type = UDSAddress
|
||||
layer_key: int = 7
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# stream: trio.SocketStream,
|
||||
# prefix_size: int = 4,
|
||||
# codec: CodecType = None,
|
||||
|
||||
# ) -> None:
|
||||
# super().__init__(
|
||||
# stream,
|
||||
# prefix_size=prefix_size,
|
||||
# codec=codec
|
||||
# )
|
||||
|
||||
@property
|
||||
def maddr(self) -> str:
|
||||
filepath = self.raddr.unwrap()
|
||||
return (
|
||||
f'/ipv4/localhost'
|
||||
f'/{self.address_type.name_key}/{filepath}'
|
||||
# f'/{self.chan.uid[0]}'
|
||||
# f'/{self.cid}'
|
||||
|
||||
# f'/cid={cid_head}..{cid_tail}'
|
||||
# TODO: ? not use this ^ right ?
|
||||
)
|
||||
|
||||
def connected(self) -> bool:
|
||||
return self.stream.socket.fileno() != -1
|
||||
|
||||
@classmethod
|
||||
async def connect_to(
|
||||
cls,
|
||||
addr: UDSAddress,
|
||||
prefix_size: int = 4,
|
||||
codec: MsgCodec|None = None,
|
||||
**kwargs
|
||||
) -> MsgpackUDSStream:
|
||||
stream = await trio.open_unix_socket(
|
||||
addr.unwrap(),
|
||||
**kwargs
|
||||
)
|
||||
return MsgpackUDSStream(
|
||||
stream,
|
||||
prefix_size=prefix_size,
|
||||
codec=codec
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_stream_addrs(
|
||||
cls,
|
||||
stream: trio.SocketStream
|
||||
) -> tuple[UDSAddress, UDSAddress]:
|
||||
return (
|
||||
UDSAddress.from_addr(stream.socket.getsockname()),
|
||||
UDSAddress.from_addr(stream.socket.getsockname()),
|
||||
)
|
|
@ -47,6 +47,7 @@ from tractor.msg import (
|
|||
pretty_struct,
|
||||
)
|
||||
from tractor.log import get_logger
|
||||
from tractor._addr import AddressTypes
|
||||
|
||||
|
||||
log = get_logger('tractor.msgspec')
|
||||
|
@ -167,8 +168,8 @@ class SpawnSpec(
|
|||
|
||||
# TODO: not just sockaddr pairs?
|
||||
# -[ ] abstract into a `TransportAddr` type?
|
||||
reg_addrs: list[tuple[str, int]]
|
||||
bind_addrs: list[tuple[str, int]]
|
||||
reg_addrs: list[AddressTypes]
|
||||
bind_addrs: list[AddressTypes]
|
||||
|
||||
|
||||
# TODO: caps based RPC support in the payload?
|
||||
|
|
Loading…
Reference in New Issue